From 69f5add84c8494e407c819b80685593483c97d32 Mon Sep 17 00:00:00 2001 From: William Zhang <133824995+2ez4bz@users.noreply.github.com> Date: Wed, 10 Jun 2026 04:32:49 -0700 Subject: [PATCH 1/6] [https://nvbugs/6272573][ci] Unwaive skipped test (#15118) Also adds another test shard for B200 since we were hitting slurm timeouts. Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com> --- jenkins/L0_Test.groovy | 16 +++++++++------- tests/integration/test_lists/waives.txt | 1 - 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/jenkins/L0_Test.groovy b/jenkins/L0_Test.groovy index 590d5b27171e..fbdbd4df3091 100644 --- a/jenkins/L0_Test.groovy +++ b/jenkins/L0_Test.groovy @@ -4092,13 +4092,15 @@ def launchTestJobs(pipeline, testFilter) "DGX_H100-4_GPUs-PyTorch-Ray-1": ["auto:dgx-h100-x4", "l0_dgx_h100", 1, 1, 4], "DGX_H100-4_GPUs-AutoDeploy-1": ["auto:dgx-h100-x4", "l0_dgx_h100", 1, 1, 4], "DGX_H100-4_GPUs-AutoDeploy-Post-Merge-1": ["auto:dgx-h100-x4", "l0_dgx_h100", 1, 1, 4], - "DGX_B200-PyTorch-1": ["auto:dgx-b200-flex", "l0_b200", 1, 7, 1, 1, true], - "DGX_B200-PyTorch-2": ["auto:dgx-b200-flex", "l0_b200", 2, 7, 1, 1, true], - "DGX_B200-PyTorch-3": ["auto:dgx-b200-flex", "l0_b200", 3, 7, 1, 1, true], - "DGX_B200-PyTorch-4": ["auto:dgx-b200-flex", "l0_b200", 4, 7, 1, 1, true], - "DGX_B200-PyTorch-5": ["auto:dgx-b200-flex", "l0_b200", 5, 7, 1, 1, true], - "DGX_B200-PyTorch-6": ["auto:dgx-b200-flex", "l0_b200", 6, 7, 1, 1, true], - "DGX_B200-PyTorch-7": ["auto:dgx-b200-flex", "l0_b200", 7, 7, 1, 1, true], + "DGX_B200-PyTorch-1": ["auto:dgx-b200-flex", "l0_b200", 1, 9, 1, 1, true], + "DGX_B200-PyTorch-2": ["auto:dgx-b200-flex", "l0_b200", 2, 9, 1, 1, true], + "DGX_B200-PyTorch-3": ["auto:dgx-b200-flex", "l0_b200", 3, 9, 1, 1, true], + "DGX_B200-PyTorch-4": ["auto:dgx-b200-flex", "l0_b200", 4, 9, 1, 1, true], + "DGX_B200-PyTorch-5": ["auto:dgx-b200-flex", "l0_b200", 5, 9, 1, 1, true], + "DGX_B200-PyTorch-6": ["auto:dgx-b200-flex", "l0_b200", 6, 9, 1, 1, true], + "DGX_B200-PyTorch-7": ["auto:dgx-b200-flex", "l0_b200", 7, 9, 1, 1, true], + "DGX_B200-PyTorch-8": ["auto:dgx-b200-flex", "l0_b200", 8, 9, 1, 1, true], + "DGX_B200-PyTorch-9": ["auto:dgx-b200-flex", "l0_b200", 9, 9, 1, 1, true], "DGX_B200-AutoDeploy-1": ["auto:dgx-b200-flex", "l0_b200", 1, 1, 1, 1, true], "DGX_B200-Triton-Post-Merge-1": ["auto:dgx-b200-flex", "l0_b200", 1, 1, 1, 1, true], "DGX_B200-PyTorch-Post-Merge-1": ["auto:dgx-b200-flex", "l0_b200", 1, 2, 1, 1, true], diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index c7ba7180d06c..e869f5202919 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -394,7 +394,6 @@ unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_t unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens16-_hidden512] SKIP (https://nvbugs/6266259) unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden32] SKIP (https://nvbugs/6266259) unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_tokens256-_hidden512] SKIP (https://nvbugs/6266259) -unittest/_torch/multimodal/test_mm_encoder_standalone.py::test_single_request_chat_multiple_images[pd_disagg-qwen3_30b_a3b_fp8] SKIP (https://nvbugs/6272573) unittest/_torch/thop/serial/test_moe.py::TestMoeFp4::test_no_autotune[use_score_as_input-RoutingDSv3-swiglu-1024-1024-1] SKIP (https://nvbugs/5908070) unittest/_torch/thop/serial/test_moe.py::TestMoeFp4::test_no_autotune[use_score_as_input-RoutingRenormalize_qwen_next-swiglu-1024-1024-150] SKIP (https://nvbugs/5908070) unittest/_torch/thop/serial/test_moe.py::TestMoeFp4::test_no_autotune[use_score_as_input-RoutingRenormalize_topk_4-swiglu-1024-1024-150] SKIP (https://nvbugs/5908070) From 74d8a484ed1c5d624e11ded8f10cf0dcef2dd953 Mon Sep 17 00:00:00 2001 From: Gal Hubara-Agam <96368689+galagam@users.noreply.github.com> Date: Wed, 10 Jun 2026 17:06:10 +0300 Subject: [PATCH 2/6] [https://nvbugs/6245279][fix] AutoDeploy: Unwaive accuracy tests (#15214) Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com> --- tests/integration/test_lists/waives.txt | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index e869f5202919..0baf89ec459f 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -9,11 +9,6 @@ accuracy/test_dwdp_disaggregated_serving.py::TestDwdpDeepSeekV3Lite::test_dwdp_a accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_gather_generation_logits_cuda_graph SKIP (https://nvbugs/5772995) accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar] SKIP (https://nvbugs/5346443) accuracy/test_llm_api.py::TestMistralNemo12B::test_fp8 SKIP (https://nvbugs/5413197) -accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[deepseek-ai_DeepSeek-R1-0528-True] SKIP (https://nvbugs/6278380) -accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[nvidia_Llama-3.1-8B-Instruct-NVFP4-True] SKIP (https://nvbugs/6245279) -accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-8] SKIP (https://nvbugs/6248757) -accuracy/test_llm_api_autodeploy.py::TestNemotronV2::test_fp8[True] SKIP (https://nvbugs/6261164) -accuracy/test_llm_api_autodeploy.py::TestQwen3_5_397B_MoE::test_nvfp4[8] SKIP (https://nvbugs/6278380) accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp_trtllm] SKIP (https://nvbugs/6191524) accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] SKIP (https://nvbugs/6084775) accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] SKIP (https://nvbugs/6029882) From 6db3233db8f7dd13a7f35f8a5ccbc909a7f601c7 Mon Sep 17 00:00:00 2001 From: Zhenhua Wang <4936589+zhenhuaw-me@users.noreply.github.com> Date: Wed, 10 Jun 2026 23:45:05 +0800 Subject: [PATCH 3/6] [TRTLLM-12491][feat] Align VisualGen serve request schema with VisualGenParams (#14733) Signed-off-by: Zhenhua Wang --- examples/visual_gen/serve/README.md | 56 +- examples/visual_gen/serve/async_video_gen.py | 32 +- examples/visual_gen/serve/sync_image_gen.py | 43 +- examples/visual_gen/serve/sync_video_gen.py | 31 +- tensorrt_llm/_torch/visual_gen/executor.py | 626 +++++++-- .../models/cosmos3/pipeline_cosmos3.py | 2 +- .../visual_gen/models/flux/pipeline_flux.py | 2 +- .../visual_gen/models/flux/pipeline_flux2.py | 2 +- .../visual_gen/models/ltx2/pipeline_ltx2.py | 12 +- .../models/ltx2/pipeline_ltx2_two_stages.py | 4 +- .../_torch/visual_gen/models/wan/defaults.py | 5 - .../visual_gen/models/wan/pipeline_wan.py | 2 +- .../visual_gen/models/wan/pipeline_wan_i2v.py | 3 +- tensorrt_llm/_torch/visual_gen/output.py | 10 +- tensorrt_llm/media/encoding.py | 30 +- tensorrt_llm/media/tensor_payload.py | 246 ++++ tensorrt_llm/serve/openai_protocol.py | 292 +++-- tensorrt_llm/serve/openai_server.py | 220 +++- tensorrt_llm/serve/openai_video_routes.py | 330 +++-- tensorrt_llm/serve/visual_gen_utils.py | 178 ++- tensorrt_llm/visual_gen/output.py | 156 ++- tensorrt_llm/visual_gen/params.py | 125 +- tensorrt_llm/visual_gen/visual_gen.py | 555 +------- .../integration/test_lists/test-db/l0_a10.yml | 1 + .../test_lists/test-db/l0_b200.yml | 1 + .../multi_gpu/test_visual_gen_multinode.py | 12 +- .../_torch/visual_gen/test_tensor_payload.py | 331 +++++ .../visual_gen/test_trtllm_serve_e2e.py | 24 +- .../visual_gen/test_trtllm_serve_endpoints.py | 1142 +++++++++++++++-- .../unittest/_torch/visual_gen/test_utils.py | 22 +- .../_torch/visual_gen/test_visual_gen_args.py | 17 +- .../visual_gen/test_visual_gen_params.py | 309 ++++- .../visual_gen/test_visual_gen_utils.py | 343 +++++ tests/unittest/media/test_encoding.py | 19 +- 34 files changed, 4010 insertions(+), 1173 deletions(-) create mode 100644 tensorrt_llm/media/tensor_payload.py create mode 100644 tests/unittest/_torch/visual_gen/test_tensor_payload.py create mode 100644 tests/unittest/_torch/visual_gen/test_visual_gen_utils.py diff --git a/examples/visual_gen/serve/README.md b/examples/visual_gen/serve/README.md index d76ece972b07..a979767b2a6a 100644 --- a/examples/visual_gen/serve/README.md +++ b/examples/visual_gen/serve/README.md @@ -67,7 +67,7 @@ Demonstrates synchronous text-to-image generation using the OpenAI SDK. Supports **Features:** - Generates images from text prompts -- Supports configurable model, image size, and quality +- Supports configurable model and image size - Returns base64-encoded images or URLs - Saves generated images to disk @@ -269,20 +269,52 @@ You can customize these by: ## Common Parameters ### Image Generation -- `model`: Model identifier (e.g., "flux1", "flux2") -- `prompt`: Text description +- `prompt`: Text description (required) - `n`: Number of images to generate -- `size`: Image dimensions (e.g., "512x512", "1024x1024") -- `quality`: "standard" or "hd" -- `response_format`: "b64_json" or "url" +- `size`: Image dimensions in `WxH` format (e.g., `"512x512"`, `"1024x1024"`) — or use the structured pair `width` + `height` (both required when sent) +- `seed`: Random seed; `null` / omitted means the engine draws a fresh seed +- `num_inference_steps`, `guidance_scale`, `max_sequence_length`, `negative_prompt`: per-request denoise controls (override pipeline defaults when sent) +- `extra_params`: model-specific overflow as a JSON object (see "Model-Specific `extra_params`" below). Unknown keys are rejected by the executor. +- `response_format`: `"b64_json"` or `"url"` +- `format`: Generation content encoding. Image encoders: `"png"`, `"webp"`, `"jpeg"`. Tensor formats: `"safetensors"`, `"pt"`. +- Accept-and-warn OpenAI-shape fields (no engine semantic): `model`, `quality`, `style`, `user`. Sending `quality`/`style` logs a server-side WARNING; sending `model` warns on mismatch. None of these change generation behavior. ### Video Generation -- `model`: Model identifier (e.g., "wan", "ltx2") -- `prompt`: Text description -- `size`: Video resolution (e.g., "256x256", "512x512", "1280x720") -- `seconds`: Duration in seconds -- `fps`: Frames per second -- `input_reference`: Reference image file (for TI2V mode) +- `prompt`: Text description (required) +- `size` / `width` / `height`: same convention as image +- `seconds`: Duration in seconds (engine multiplies by `frame_rate` to derive `num_frames` when the latter is absent) +- `frame_rate` (canonical) or `fps` (alias): frames per second +- `num_frames`: when set, wins over the `seconds * frame_rate` derivation +- `seed`, `num_inference_steps`, `guidance_scale`, `max_sequence_length`, `negative_prompt`: per-request denoise controls +- `input_reference`: Reference image (TI2V mode); accepted as base64-encoded string in JSON or as a file in multipart form-data +- `extra_params`: model-specific overflow (see below) +- `response_format`: `"b64_json"` or `"url"` +- `format`: Generation content encoding. Video encoders: `"mp4"`, `"avi"`, `"auto"`. Tensor formats: `"safetensors"`, `"pt"` (carries video + audio + scalar metadata in one payload for LTX-2). + +#### Tensor-format consumer contract + +When `format="safetensors"` or `format="pt"`, the payload bundles every populated media tensor (`image` / `video` / `audio`) and the scalar metadata (`frame_rate`, `audio_sample_rate`) into one file. + +- **`pt`**: `torch.load(buf, weights_only=True)` returns a dict with the tensor keys and the scalars as native Python values. +- **`safetensors`**: `safetensors.torch.load(bytes)` returns a dict with the tensor keys and each scalar as a 0-d tensor under the same key — call `.item()` to unbox (e.g. `loaded["frame_rate"].item()`). The same scalars are also written to the safetensors file header as strings; `safe_open(path, framework="pt").metadata()` exposes them in that form for consumers that prefer header access. + +#### Unknown-field policy + +The visual-gen endpoints reject unknown top-level fields with HTTP 422 (`extra="forbid"`). Anything model-specific belongs inside `extra_params`. Sending `output_format`, top-level `guidance_rescale`, or — for video — top-level `n` returns 422 with the offending field named in the error body. + +#### Model-specific `extra_params` + +Use the Python API to discover accepted keys for a loaded pipeline: + +```python +generator = VisualGen(model="...") +print(generator.extra_param_specs) # {key: ExtraParamSchema(type=..., range=..., default=..., description=...)} +``` + +Examples: +- **LTX-2**: `stg_scale`, `stg_blocks`, `modality_scale`, `guidance_rescale`, `output_type`, ... +- **Wan 2.2 A14B**: `guidance_scale_2`, `boundary_ratio` +- **Wan 2.1 / Flux**: no model-specific `extra_params` declared > **Note:** LTX-2 generates video **with audio**. The `ltx2.yml` config must include > `text_encoder_path` pointing to a Gemma3 model (e.g., `google/gemma-3-12b-it`). diff --git a/examples/visual_gen/serve/async_video_gen.py b/examples/visual_gen/serve/async_video_gen.py index 42e64594c99f..b884de0bbd22 100755 --- a/examples/visual_gen/serve/async_video_gen.py +++ b/examples/visual_gen/serve/async_video_gen.py @@ -33,7 +33,7 @@ def test_async_video_generation( fps: int = 24, size: str = "256x256", output_file: str = "output_async.mp4", - output_format: str = "auto", + format: str = "auto", ): """Test asynchronous video generation with OpenAI SDK. @@ -77,7 +77,7 @@ def test_async_video_generation( "seconds": duration, "extra_body": { "fps": fps, - "output_format": output_format, + "format": format, }, } @@ -131,12 +131,18 @@ def test_async_video_generation( # For binary content, use the underlying HTTP client content = client.videos.download_content(video_id, variant="video") - # Check content type to determine actual file extension - content_type = getattr(content.response, "headers", {}).get("content-type", "video/mp4") - if "x-msvideo" in content_type or "avi" in content_type: - actual_ext = ".avi" + # Determine the on-disk extension. Tensor formats are + # selected by the request and the server returns + # ``application/octet-stream``; encoder formats can be + # disambiguated from Content-Type (mp4 vs avi). + if format in ("safetensors", "pt"): + actual_ext = f".{format}" else: - actual_ext = ".mp4" + content_type = getattr(content.response, "headers", {}).get("content-type", "video/mp4") + if "x-msvideo" in content_type or "avi" in content_type: + actual_ext = ".avi" + else: + actual_ext = ".mp4" # Adjust output filename if extension doesn't match output_path = Path(output_file) @@ -233,11 +239,15 @@ def test_async_video_generation( ) parser.add_argument( - "--output-format", + "--format", type=str, default="auto", - choices=["mp4", "avi", "auto"], - help="Output video format: mp4 or avi or auto", + choices=["mp4", "avi", "auto", "safetensors", "pt"], + help=( + "Generation content encoding format. Encoders: mp4 / avi / auto. " + "Tensor formats safetensors / pt return raw tensor bytes for " + "programmatic post-processing." + ), ) args = parser.parse_args() @@ -264,7 +274,7 @@ def test_async_video_generation( fps=args.fps, size=args.size, output_file=args.output, - output_format=args.output_format, + format=args.format, ) sys.exit(0 if success else 1) diff --git a/examples/visual_gen/serve/sync_image_gen.py b/examples/visual_gen/serve/sync_image_gen.py index 9f9c971dc7d9..ef3fa58453ce 100755 --- a/examples/visual_gen/serve/sync_image_gen.py +++ b/examples/visual_gen/serve/sync_image_gen.py @@ -28,11 +28,16 @@ def test_image_generation( prompt: str = "A lovely cat lying on a sofa", n: int = 1, size: str = "512x512", - quality: str = "standard", + format: str = "png", response_format: str = "b64_json", output_file: str = "output_generation.png", ): - """Test image generation endpoint.""" + """Test image generation endpoint. + + ``format`` selects the encoding for the returned bytes. Image encoders + are ``"png"``, ``"webp"``, ``"jpeg"``; tensor payloads are + ``"safetensors"`` and ``"pt"``. + """ print("=" * 80) print("Testing Image Generation API (POST /v1/images/generations)") print("=" * 80) @@ -44,30 +49,41 @@ def test_image_generation( print(f" Model: {model}") print(f" Prompt: {prompt}") print(f" Size: {size}") - print(f" Quality: {quality}") + print(f" Format: {format}") print(f" Number of images: {n}") try: - # Use OpenAI SDK's images.generate() method + # ``format`` is a trtllm-serve extension over the OpenAI image + # API; the SDK forwards it via ``extra_body``. response = client.images.generate( model=model, prompt=prompt, n=n, size=size, - quality=quality, response_format=response_format, + extra_body={"format": format}, ) print("\n✓ Image generated successfully!") print(f" Number of images: {len(response.data)}") + # Choose the on-disk extension to match the requested format so + # the saved file's suffix reflects its actual contents. + ext_map = { + "png": ".png", + "webp": ".webp", + "jpeg": ".jpeg", + "safetensors": ".safetensors", + "pt": ".pt", + } + ext = ext_map[format] + stem = output_file.rsplit(".", 1)[0] + # Save images for i, image in enumerate(response.data): if response_format == "b64_json": - # Decode base64 image image_data = base64.b64decode(image.b64_json) - output = f"{output_file.rsplit('.', 1)[0]}_{i}.png" if n > 1 else output_file - + output = f"{stem}_{i}{ext}" if n > 1 else f"{stem}{ext}" with open(output, "wb") as f: f.write(image_data) @@ -116,6 +132,16 @@ def test_image_generation( default="512x512", help="Image size in WxH format (e.g., 512x512, 1024x1024)", ) + parser.add_argument( + "--format", + type=str, + default="png", + choices=["png", "webp", "jpeg", "safetensors", "pt"], + help=( + "Generation content encoding format. Image encoders: png / " + "webp / jpeg. Tensor payloads: safetensors / pt." + ), + ) parser.add_argument( "--output", type=str, @@ -137,6 +163,7 @@ def test_image_generation( model=args.model, prompt=args.prompt, size=args.size, + format=args.format, output_file=args.output, ) diff --git a/examples/visual_gen/serve/sync_video_gen.py b/examples/visual_gen/serve/sync_video_gen.py index 5842c3dc99de..4de8ee5072f8 100755 --- a/examples/visual_gen/serve/sync_video_gen.py +++ b/examples/visual_gen/serve/sync_video_gen.py @@ -32,6 +32,7 @@ def test_sync_video_generation( fps: int = 24, size: str = "256x256", output_file: str = "output_sync.mp4", + format: str = "auto", ): """Test synchronous video generation with direct HTTP requests. @@ -79,6 +80,7 @@ def test_sync_video_generation( "size": size, "seconds": str(duration), "fps": str(fps), + "format": format, } # Add the file @@ -103,18 +105,25 @@ def test_sync_video_generation( "size": size, "seconds": duration, "fps": fps, + "format": format, }, ) print(f"\nStatus code: {response_video.status_code}") if response_video.status_code == 200: - # Determine actual file extension from Content-Type header - content_type = response_video.headers.get("content-type", "video/mp4") - if "x-msvideo" in content_type or "avi" in content_type: - actual_ext = ".avi" + # Determine the on-disk extension. Tensor formats are + # selected by the request and the server returns + # ``application/octet-stream``; encoder formats can be + # disambiguated from Content-Type (mp4 vs avi). + if format in ("safetensors", "pt"): + actual_ext = f".{format}" else: - actual_ext = ".mp4" + content_type = response_video.headers.get("content-type", "video/mp4") + if "x-msvideo" in content_type or "avi" in content_type: + actual_ext = ".avi" + else: + actual_ext = ".mp4" # Adjust output filename if extension doesn't match output_path = Path(output_file) @@ -214,6 +223,17 @@ def test_sync_video_generation( default="output_sync.mp4", help="Output video file path (extension may change based on server encoder: .mp4 or .avi)", ) + parser.add_argument( + "--format", + type=str, + default="auto", + choices=["mp4", "avi", "auto", "safetensors", "pt"], + help=( + "Generation content encoding format. Video encoders: mp4 / " + "avi / auto. Tensor payloads: safetensors / pt carry video + " + "audio + scalar metadata in a single file." + ), + ) args = parser.parse_args() @@ -239,6 +259,7 @@ def test_sync_video_generation( fps=args.fps, size=args.size, output_file=args.output, + format=args.format, ) sys.exit(0 if success else 1) diff --git a/tensorrt_llm/_torch/visual_gen/executor.py b/tensorrt_llm/_torch/visual_gen/executor.py index 8b4f2dd876d2..10ae907905bd 100644 --- a/tensorrt_llm/_torch/visual_gen/executor.py +++ b/tensorrt_llm/_torch/visual_gen/executor.py @@ -1,13 +1,16 @@ +import asyncio import os import queue +import socket import threading import time import traceback from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union import torch import torch.distributed as dist +import torch.multiprocessing as mp import zmq from tensorrt_llm._torch.visual_gen.output import PipelineOutput @@ -19,6 +22,74 @@ if TYPE_CHECKING: from tensorrt_llm.visual_gen.params import VisualGenParams +# Timeouts (seconds) for the client-side coordinator. +POLL_TIMEOUT = 0.01 +AWAIT_TIMEOUT = 0.05 +THREAD_TIMEOUT = 5.0 +WORKER_TIMEOUT = 2.0 + + +def find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def get_ip_address() -> str: + """Get local IP address.""" + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("10.255.255.255", 1)) + return s.getsockname()[0] + except Exception: + return "127.0.0.1" + finally: + s.close() + + +def _detect_external_launch() -> Optional[Tuple[int, int, int, str, int]]: + """Detect whether the process was launched by an external distributed launcher. + + Checks for torchrun (``RANK`` + ``WORLD_SIZE``) and then SLURM + (``SLURM_PROCID`` + ``SLURM_NTASKS``). Returns a + ``(rank, local_rank, world_size, master_addr, master_port)`` tuple when a + multi-process launcher is detected (world_size > 1), or ``None`` for + single-process / single-node ``mp.Process`` mode. + """ + # torchrun / torchelastic sets RANK and WORLD_SIZE + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + if world_size > 1: + local_rank = int(os.environ.get("LOCAL_RANK", rank)) + master_addr = os.environ.get("MASTER_ADDR") + if master_addr is None: + raise RuntimeError( + "MASTER_ADDR must be set for multi-node torchrun runs. " + "Add --master-addr= to your torchrun command, or set " + "MASTER_ADDR in the environment before launching." + ) + master_port = int(os.environ.get("MASTER_PORT", 29500)) + return rank, local_rank, world_size, master_addr, master_port + + # SLURM: srun --ntasks-per-node=GPUS_PER_NODE sets SLURM_PROCID / SLURM_NTASKS + if "SLURM_PROCID" in os.environ and "SLURM_NTASKS" in os.environ: + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NTASKS"]) + if world_size > 1: + local_rank = int(os.environ.get("SLURM_LOCALID", rank)) + master_addr = os.environ.get("MASTER_ADDR") + if master_addr is None: + raise RuntimeError( + "MASTER_ADDR must be set for multi-node SLURM runs. " + "Add to your sbatch script:\n" + " MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -1)" + ) + master_port = int(os.environ.get("MASTER_PORT", 29500)) + return rank, local_rank, world_size, master_addr, master_port + + return None + @dataclass class DiffusionRequest: @@ -59,33 +130,6 @@ class DiffusionResponse: generation: float = 0.0 -# Python type name → accepted Python types for ExtraParamSchema validation. -_TYPE_MAP = { - "float": (float, int), - "int": (int,), - "bool": (bool,), - "str": (str,), - "list": (list,), -} - -# Generation config fields that pipelines declare defaults for. -# If a user sets one of these but the pipeline doesn't declare it in -# default_generation_params, the value will be silently ignored. -# Conditioning inputs (image, negative_prompt, mask, image_cond_strength) -# are excluded — they are validated at runtime by the pipeline's infer(). -_GENERATION_CONFIG_FIELDS = frozenset( - { - "height", - "width", - "num_inference_steps", - "guidance_scale", - "max_sequence_length", - "num_frames", - "frame_rate", - } -) - - class DiffusionExecutor: """Execution engine for diffusion models running in worker processes.""" @@ -192,7 +236,10 @@ def serve_forever(self): req = self.requests_ipc.get() logger.info(f"Worker {self.device_id}: Request available") - # Broadcast to all ranks + # Broadcast to all ranks. ``req.params.seed`` is already a + # concrete int — resolved once on the coordinator process at + # :meth:`VisualGen.generate_async` entry — so the broadcast + # propagates the same value to every rank. obj_list = [req] dist.broadcast_object_list(obj_list, src=0) req = obj_list[0] @@ -211,18 +258,10 @@ def _merge_defaults(self, req: DiffusionRequest): """Fill ``None`` fields in *req.params* with pipeline-specific defaults. Merges both universal defaults (from ``default_generation_params``) - and extra_param defaults (from ``extra_param_specs``). + and extra_param defaults (from ``extra_param_specs``). ``req.params`` + is expected to be a concrete :class:`VisualGenParams`; defaults are + materialized at the :class:`VisualGen.generate_async` enqueue site. """ - if req.params is None: - from tensorrt_llm.visual_gen.params import VisualGenParams - - kwargs = dict(self.pipeline.default_generation_params) - specs = self.pipeline.extra_param_specs - if specs: - kwargs["extra_params"] = {key: spec.default for key, spec in specs.items()} - req.params = VisualGenParams(**kwargs) - return - params = req.params # Universal field defaults for field_name, default_value in self.pipeline.default_generation_params.items(): @@ -238,78 +277,6 @@ def _merge_defaults(self, req: DiffusionRequest): if key not in params.extra_params: params.extra_params[key] = spec.default - self._validate_request(req) - - def _validate_request(self, req: DiffusionRequest): - """Validate *req.params* against the loaded pipeline's declared parameters. - - Raises ``ValueError`` on: - - Unknown ``extra_params`` keys - - Universal fields (e.g. ``num_frames``) set by the user but not - declared in the pipeline's ``default_generation_params`` - - Type mismatches for ``extra_params`` values - - Out-of-range ``extra_params`` values - """ - params = req.params - errors: list[str] = [] - pipeline_name = self.pipeline.__class__.__name__ - declared_defaults = self.pipeline.default_generation_params - specs = self.pipeline.extra_param_specs - - # --- unknown extra_params keys --- - if params.extra_params: - unknown = set(params.extra_params.keys()) - set(specs.keys()) - if unknown: - errors.append( - f"Unknown extra_params {sorted(unknown)} for {pipeline_name}. " - f"Supported: {sorted(specs.keys())}" - ) - - # --- unsupported universal fields --- - # Check generation config fields the user explicitly set (not None) - # that the pipeline never declared in default_generation_params. - # Conditioning inputs (image, negative_prompt, mask) are excluded — - # they are validated at runtime by the pipeline's infer(). - for field_name in _GENERATION_CONFIG_FIELDS: - value = getattr(params, field_name, None) - if value is not None and field_name not in declared_defaults: - errors.append( - f"Parameter '{field_name}' is set but {pipeline_name} does " - f"not use it (not in default_generation_params). " - f"It will be silently ignored." - ) - - # --- extra_params type and range checks --- - if params.extra_params: - for key, value in params.extra_params.items(): - if key not in specs: - continue # already reported as unknown above - spec = specs[key] - # Skip None values (param left at its None default) - if value is None: - continue - # Type check - expected_types = _TYPE_MAP.get(spec.type) - if expected_types and not isinstance(value, expected_types): - errors.append( - f"extra_params['{key}'] expected type '{spec.type}', " - f"got {type(value).__name__}: {value!r}" - ) - continue # skip range check if type is wrong - # Range check (numeric only) - if spec.range is not None and isinstance(value, (int, float)): - lo, hi = spec.range - if not (lo <= value <= hi): - errors.append( - f"extra_params['{key}'] value {value} is out of range [{lo}, {hi}]" - ) - - if errors: - msg = f"Parameter validation failed for {pipeline_name}:\n" + "\n".join( - f" - {e}" for e in errors - ) - raise ValueError(msg) - def process_request(self, req: DiffusionRequest): """Process a single request.""" try: @@ -411,3 +378,438 @@ def run_diffusion_worker( except Exception as e: logger.error(f"Worker failed: {e}") traceback.print_exc() + + +class DiffusionRemoteClient: + """Client proxy for remote DiffusionExecutor in worker processes. + + Internal coordinator-side counterpart to :class:`DiffusionExecutor`. Not + part of the public ``tensorrt_llm.visual_gen`` API; the user-facing + entry point is :class:`tensorrt_llm.visual_gen.VisualGen`, which resolves + every request's seed before reaching :meth:`enqueue_requests`. + + Supports two launch modes: + + **Single-node (default)** + ``VisualGen`` is called from an ordinary Python script. + ``DiffusionRemoteClient`` spawns all worker processes locally via + ``mp.Process`` with ``master_addr=127.0.0.1``. + + **Multi-node (external launcher)** + The script is launched by ``torchrun`` or ``srun --ntasks-per-node=GPUS``. + Each rank runs the same script; ``RANK`` / ``WORLD_SIZE`` / ``MASTER_ADDR`` + / ``MASTER_PORT`` are already set in the environment. + + - Rank 0: becomes the request coordinator. It creates the ZMQ server + sockets and starts its own worker in a background thread, then returns + to the caller so the user script can call ``generate()``. + - Rank > 0: handled by ``VisualGen.__init__`` before this class is + instantiated — they call ``run_diffusion_worker`` directly and exit + via ``sys.exit(0)``. These ranks never reach ``DiffusionRemoteClient``. + """ + + def __init__( + self, + args: VisualGenArgs, + ): + self.args = args + self.n_workers = args.parallel_config.n_workers + + # --- Detect external launcher (torchrun / srun) --- + ext = _detect_external_launch() + + if ext is None: + # Single-node: coordinator spawns all workers locally + # Setup distributed env + self.master_addr = "127.0.0.1" + self.master_port = find_free_port() + + # Setup IPC addresses + self.host_ip = get_ip_address() + req_port, resp_port = find_free_port(), find_free_port() + + self.request_queue_addr = f"tcp://0.0.0.0:{req_port}" + self.response_queue_addr = f"tcp://0.0.0.0:{resp_port}" + self.req_addr_connect = f"tcp://{self.host_ip}:{req_port}" + self.resp_addr_connect = f"tcp://{self.host_ip}:{resp_port}" + + else: + # rank == 0 guaranteed — ranks 1..N-1 exited in VisualGen.__init__ + rank, local_rank, world_size, master_addr, master_port = ext + req_port = find_free_port() + resp_port = find_free_port() + self.master_addr = master_addr + self.master_port = master_port + self.request_queue_addr = f"tcp://0.0.0.0:{req_port}" + self.response_queue_addr = f"tcp://0.0.0.0:{resp_port}" + self.req_addr_connect = f"tcp://{master_addr}:{req_port}" + self.resp_addr_connect = f"tcp://{master_addr}:{resp_port}" + + # Generate shared HMAC keys for IPC authentication + self.req_hmac_key = os.urandom(32) + self.resp_hmac_key = os.urandom(32) + + # IPC setup + self.requests_ipc = None + self.responses_ipc = None + self.pending_requests = queue.Queue() + self.completed_responses: Dict[int, DiffusionResponse] = {} + # Request ids the caller has given up on (e.g., aresult timed out). + # _store_response drops late-arriving responses for these ids so a + # full PipelineOutput tensor does not pin in completed_responses for + # the process lifetime. + self._abandoned_request_ids: Set[int] = set() + + # We'll create asyncio primitives in the background thread's event loop + self._event_loop = None + self.response_event = None + self.lock = None + self.shutdown_event = threading.Event() + self.event_loop_ready = threading.Event() + + # Start background thread (it will create its own event loop) + self.background_thread = threading.Thread(target=self._serve_forever_thread, daemon=True) + self.background_thread.start() + + # Wait for the background thread to initialize the event loop + self.event_loop_ready.wait() + + # Pipeline metadata — populated by _wait_ready from the READY signal. + self.default_generation_params: Dict = {} + self.extra_param_specs: Dict = {} + + # --- Launch workers --- + self.worker_processes = [] + self._ext_worker_thread: Optional[threading.Thread] = None + + if ext is None: + logger.info(f"DiffusionClient: Launching {self.n_workers} workers") + ctx = mp.get_context("spawn") + for rank in range(self.n_workers): + p = ctx.Process( + target=run_diffusion_worker, + kwargs={ + "rank": rank, + "world_size": self.n_workers, + "master_addr": self.master_addr, + "master_port": self.master_port, + "request_queue_addr": self.req_addr_connect, + "response_queue_addr": self.resp_addr_connect, + "visual_gen_args": self.args, + "req_hmac_key": self.req_hmac_key, + "resp_hmac_key": self.resp_hmac_key, + "log_level": logger.level, + "local_rank": rank, + }, + ) + p.start() + self.worker_processes.append(p) + else: + # External launch: rank 0 runs its own worker in a background thread. + # Other nodes' workers are already running (they were launched by the + # external launcher and will connect to our ZMQ server once it binds). + self._ext_worker_thread = threading.Thread( + target=run_diffusion_worker, + kwargs={ + "rank": rank, + "world_size": self.n_workers, + "master_addr": master_addr, + "master_port": master_port, + "request_queue_addr": self.req_addr_connect, + "response_queue_addr": self.resp_addr_connect, + "visual_gen_args": self.args, + "req_hmac_key": self.req_hmac_key, + "resp_hmac_key": self.resp_hmac_key, + "log_level": logger.level, + "local_rank": local_rank, + }, + daemon=True, + ) + self._ext_worker_thread.start() + + self._wait_ready() + + @staticmethod + def _close_socket(ipc_queue): + if ipc_queue and ipc_queue.socket: + ipc_queue.socket.setsockopt(zmq.LINGER, 0) + ipc_queue.close() + + def enqueue_requests(self, requests: List[DiffusionRequest]) -> List[int]: + """Enqueue requests and return their IDs.""" + req_ids = [] + for req in requests: + self.pending_requests.put(req) + req_ids.append(req.request_id) + return req_ids + + async def await_responses( + self, request_ids: Union[int, List[int]], timeout: Optional[float] = None + ) -> Union[DiffusionResponse, List[DiffusionResponse]]: + """Wait for responses by request IDs. + + Args: + request_ids: Single request ID or list of request IDs to wait for + timeout: Maximum total wait time in seconds (None = wait indefinitely) + + Returns: + Single response or list of responses (None if request timed out) + """ + is_single = isinstance(request_ids, int) + ids = [request_ids] if is_single else request_ids + + start_time = time.time() + results = {} + + while len(results) < len(ids): + async with self.lock: + for req_id in ids: + if req_id in self.completed_responses: + results[req_id] = self.completed_responses.pop(req_id) + + # All responses collected + if len(results) == len(ids): + break + + # Check if overall timeout exceeded + if timeout is not None: + elapsed = time.time() - start_time + if elapsed >= timeout: + break + # Wait for remaining time or AWAIT_TIMEOUT, whichever is shorter + wait_time = min(timeout - elapsed, AWAIT_TIMEOUT) + else: + wait_time = AWAIT_TIMEOUT + + try: + await asyncio.wait_for(self.response_event.wait(), timeout=wait_time) + except asyncio.TimeoutError: + pass + self.response_event.clear() + + out = [results.get(rid) for rid in ids] + return out[0] if is_single else out + + def await_responses_sync( + self, request_ids: Union[int, List[int]], timeout: Optional[float] = None + ) -> Union[DiffusionResponse, List[DiffusionResponse]]: + """Sync wrapper to await responses from the main thread.""" + future = asyncio.run_coroutine_threadsafe( + self.await_responses(request_ids, timeout), self._event_loop + ) + return future.result(timeout=timeout if timeout else None) + + def _init_ipc(self) -> bool: + """Initialize IPC queues.""" + try: + logger.info("DiffusionClient: Initializing IPC") + self.requests_ipc = ZeroMqQueue( + (self.request_queue_addr, self.req_hmac_key), + is_server=True, + socket_type=zmq.PUSH, + use_hmac_encryption=True, + ) + self.responses_ipc = ZeroMqQueue( + (self.response_queue_addr, self.resp_hmac_key), + is_server=True, + socket_type=zmq.PULL, + use_hmac_encryption=True, + ) + logger.info("DiffusionClient: IPC ready") + return True + except Exception as e: + logger.error(f"DiffusionClient: IPC init failed: {e}") + return False + + def _send_shutdown(self): + """Send shutdown signal.""" + logger.info("DiffusionClient: Sending shutdown signal") + if self.requests_ipc: + self.requests_ipc.put(None) + self._close_socket(self.requests_ipc) + + def _process_requests(self): + """Process pending requests.""" + try: + req = self.pending_requests.get(timeout=POLL_TIMEOUT) + if req is None: + self._send_shutdown() + self.shutdown_event.set() + return + + logger.info(f"DiffusionClient: Sending request {req.request_id}") + self.requests_ipc.put(req) + except queue.Empty: + pass + except Exception as e: + logger.error(f"DiffusionClient: Error sending request: {e}") + logger.error(traceback.format_exc()) + + def _process_responses(self): + """Poll and process responses.""" + try: + if self.responses_ipc.poll(timeout=POLL_TIMEOUT): + response = self.responses_ipc.get() + if isinstance(response, DiffusionResponse): + if response.request_id == -1: + logger.info("DiffusionClient: Received READY signal") + + # Schedule the lock acquisition and event setting in the event loop + asyncio.run_coroutine_threadsafe( + self._store_response(response), self._event_loop + ) + except Exception as e: + logger.error(f"DiffusionClient: Error processing response: {e}") + + async def _store_response(self, response: DiffusionResponse): + """Store response in the completed_responses dict (async helper). + + Drops the response if the request id has been abandoned so that + late-arriving responses for timed-out requests do not leak into + ``completed_responses`` for the process lifetime. + """ + async with self.lock: + if response.request_id in self._abandoned_request_ids: + self._abandoned_request_ids.discard(response.request_id) + return + self.completed_responses[response.request_id] = response + self.response_event.set() + + async def abandon_request_id(self, request_id: int): + """Mark a request id as abandoned and drop any cached response. + + Called from the result handle's timeout branch to prevent the + executor from holding a full ``PipelineOutput`` for a request whose + caller has stopped waiting. Handles both orderings: + + - Response already arrived between the timeout firing and the + abandon call → ``pop`` releases it here. + - Response arrives after the abandon call → ``_store_response`` + checks the abandoned set and drops it on arrival. + """ + async with self.lock: + self.completed_responses.pop(request_id, None) + self._abandoned_request_ids.add(request_id) + + def _cleanup_ipc(self): + """Cleanup IPC.""" + logger.info("DiffusionClient: Cleaning up IPC") + self._close_socket(self.requests_ipc) + self._close_socket(self.responses_ipc) + + def _serve_forever_thread(self): + """Background thread wrapper that creates and runs an event loop.""" + logger.info("DiffusionClient: Background thread started") + + # Create a new event loop for this thread + self._event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._event_loop) + + # Create async primitives in this thread's event loop + self.response_event = asyncio.Event() + self.lock = asyncio.Lock() + + # Signal that the event loop is ready + self.event_loop_ready.set() + + # Run the async serve_forever + try: + self._event_loop.run_until_complete(self._serve_forever()) + finally: + self._event_loop.close() + logger.info("DiffusionClient: Background thread stopped") + + async def _serve_forever(self): + """Background thread main loop (async version).""" + if not self._init_ipc(): + return + + while not self.shutdown_event.is_set(): + self._process_requests() + self._process_responses() + await asyncio.sleep(0.001) # Yield control to allow other coroutines to run + + self._cleanup_ipc() + + def shutdown(self): + """Shutdown client and workers.""" + logger.info("DiffusionClient: Shutting down") + self.pending_requests.put(None) + + self.background_thread.join(timeout=THREAD_TIMEOUT) + if self.background_thread.is_alive(): + logger.warning("DiffusionClient: Force stopping background thread") + self.shutdown_event.set() + self.background_thread.join(timeout=1.0) + + # Shutdown workers + logger.info("DiffusionClient: Stopping workers") + for p in self.worker_processes: + p.join(timeout=WORKER_TIMEOUT) + if p.is_alive(): + logger.warning(f"DiffusionClient: Terminating worker {p.pid} with SIGTERM") + p.terminate() + p.join(timeout=WORKER_TIMEOUT) + if p.is_alive(): + logger.warning(f"DiffusionClient: Force killing worker {p.pid} with SIGKILL") + p.kill() + p.join(timeout=WORKER_TIMEOUT) + + # External-launch mode: join rank-0 worker thread + if self._ext_worker_thread is not None and self._ext_worker_thread.is_alive(): + self._ext_worker_thread.join(timeout=WORKER_TIMEOUT) + + def _wait_ready(self): + """Wait for workers to be ready (sync wrapper for async operation).""" + logger.info("DiffusionClient: Waiting for workers") + + future = asyncio.run_coroutine_threadsafe(self._wait_ready_async(), self._event_loop) + try: + future.result() + except Exception: + self.shutdown() + raise + + async def _wait_ready_async(self): + """Wait for workers to be ready (async version). + + Polls indefinitely for the ready signal. If any worker process dies + during initialization, raises RuntimeError immediately (LLM-style). + """ + start_time = time.time() + last_log_time = start_time + log_interval = 300 + + while True: + async with self.lock: + if -1 in self.completed_responses: + ready_resp = self.completed_responses.pop(-1) + # Extract pipeline metadata from the READY payload. + payload = ready_resp.output + if isinstance(payload, dict): + self.default_generation_params = payload.get( + "default_generation_params", {} + ) + self.extra_param_specs = payload.get("extra_param_specs", {}) + elapsed = time.time() - start_time + logger.info(f"DiffusionClient: Workers ready ({elapsed:.1f}s)") + return + + worker_dead = any(not p.is_alive() for p in self.worker_processes) + ext_dead = ( + self._ext_worker_thread is not None and not self._ext_worker_thread.is_alive() + ) + if worker_dead or ext_dead: + raise RuntimeError("DiffusionClient: Worker died during initialization") + + now = time.time() + if now - last_log_time >= log_interval: + elapsed = now - start_time + logger.info(f"DiffusionClient: Still waiting for workers ({elapsed:.0f}s elapsed)") + last_log_time = now + + try: + await asyncio.wait_for(self.response_event.wait(), timeout=AWAIT_TIMEOUT) + except asyncio.TimeoutError: + pass + self.response_event.clear() diff --git a/tensorrt_llm/_torch/visual_gen/models/cosmos3/pipeline_cosmos3.py b/tensorrt_llm/_torch/visual_gen/models/cosmos3/pipeline_cosmos3.py index 732dc5d4fa9d..f22b07c9d208 100644 --- a/tensorrt_llm/_torch/visual_gen/models/cosmos3/pipeline_cosmos3.py +++ b/tensorrt_llm/_torch/visual_gen/models/cosmos3/pipeline_cosmos3.py @@ -443,6 +443,7 @@ def _decode_latents(self, latents): def forward( self, prompt: Union[str, List[str]], + seed: int, negative_prompt: Optional[str] = None, image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None, height: int = COSMOS3_720P_PARAMS["height"], @@ -450,7 +451,6 @@ def forward( num_frames: int = COSMOS3_720P_PARAMS["num_frames"], num_inference_steps: int = COSMOS3_720P_PARAMS["num_inference_steps"], guidance_scale: float = COSMOS3_720P_PARAMS["guidance_scale"], - seed: int = 42, max_sequence_length: int = COSMOS3_720P_PARAMS["max_sequence_length"], frame_rate: float = COSMOS3_720P_PARAMS["frame_rate"], use_duration_template: bool = COSMOS3_EXTRA_SPECS["use_duration_template"].default, diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py index 196c3c927e7a..cf2dd468a6df 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py @@ -264,11 +264,11 @@ def infer(self, req): def forward( self, prompt: Union[str, List[str]], + seed: int, height: int = 1024, width: int = 1024, num_inference_steps: int = 50, guidance_scale: float = 3.5, - seed: int = 42, max_sequence_length: int = 512, num_images_per_prompt: int = 1, ): diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py index 82302aac8ce3..19db372cd310 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py @@ -357,11 +357,11 @@ def infer(self, req): def forward( self, prompt: Union[str, List[str]], + seed: int, height: int = 1024, width: int = 1024, num_inference_steps: int = 50, guidance_scale: float = 3.5, - seed: int = 42, max_sequence_length: int = 512, num_images_per_prompt: int = 1, ): diff --git a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py index f48b3d05aea7..aa1fb75c360b 100644 --- a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py +++ b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py @@ -1275,7 +1275,6 @@ def default_generation_params(self): "max_sequence_length": 1024, "num_frames": 121, "frame_rate": 24.0, - "image_cond_strength": 1.0, } @property @@ -1291,6 +1290,11 @@ def extra_param_specs(self): default=0.0, description="Guidance rescale factor to prevent overexposure.", ), + "image_cond_strength": ExtraParamSchema( + type="float", + default=1.0, + description="Image conditioning strength for I2V (1.0 = fully conditioned first frame).", + ), "stg_scale": ExtraParamSchema( type="float", default=0.0, @@ -1340,7 +1344,7 @@ def infer(self, req): guidance_rescale=extra["guidance_rescale"], max_sequence_length=req.params.max_sequence_length, image=req.params.image, - image_cond_strength=req.params.image_cond_strength, + image_cond_strength=extra["image_cond_strength"], stg_scale=extra["stg_scale"], stg_blocks=extra["stg_blocks"], modality_scale=extra["modality_scale"], @@ -1353,7 +1357,7 @@ def infer(self, req): # Prompt enhancement # ------------------------------------------------------------------ - def _enhance_prompt(self, prompt: str, seed: int = 42) -> str: + def _enhance_prompt(self, prompt: str, seed: int) -> str: """Use Gemma3 as an LLM to enhance the prompt for video generation.""" system_prompt = ( "You are a helpful assistant that enhances text prompts for video generation. " @@ -1392,6 +1396,7 @@ def _enhance_prompt(self, prompt: str, seed: int = 42) -> str: def forward( self, prompt: Union[str, List[str]], + seed: int, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 768, @@ -1400,7 +1405,6 @@ def forward( num_inference_steps: int = 40, guidance_scale: float = 4.0, guidance_rescale: float = 0.0, - seed: int = 42, output_type: str = "pt", max_sequence_length: int = 1024, image: Optional[Union[str, torch.Tensor]] = None, diff --git a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2_two_stages.py b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2_two_stages.py index 776ca9885a78..c6e6bd9b8f72 100644 --- a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2_two_stages.py +++ b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2_two_stages.py @@ -735,7 +735,7 @@ def infer(self, req): guidance_rescale=extra["guidance_rescale"], max_sequence_length=req.params.max_sequence_length, image=req.params.image, - image_cond_strength=req.params.image_cond_strength, + image_cond_strength=extra["image_cond_strength"], stg_scale=extra["stg_scale"], stg_blocks=extra["stg_blocks"], modality_scale=extra["modality_scale"], @@ -752,6 +752,7 @@ def infer(self, req): def forward( self, prompt: Union[str, List[str]], + seed: int, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 768, @@ -760,7 +761,6 @@ def forward( num_inference_steps: int = 40, guidance_scale: float = 3.0, guidance_rescale: float = 0.0, - seed: int = 42, output_type: str = "pt", max_sequence_length: int = 1024, image: Optional[Union[str, torch.Tensor]] = None, diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/defaults.py b/tensorrt_llm/_torch/visual_gen/models/wan/defaults.py index 046331db1ec7..5ceacfe8af4c 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/defaults.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/defaults.py @@ -110,7 +110,6 @@ def get_wan_default_params( num_heads: int, *, is_wan22_5b: bool = False, - include_i2v: bool = False, ) -> dict: """Return the default generation params dict for a Wan model. @@ -119,7 +118,6 @@ def get_wan_default_params( name_or_path: Checkpoint path or HF model ID (_name_or_path). num_heads: Number of attention heads from transformer config. is_wan22_5b: Whether this is a Wan 2.2 TI2V-5B model. - include_i2v: If True, add I2V-specific defaults (image_cond_strength). """ if is_wan22_5b: params = dict(_WAN22_5B_PARAMS) @@ -130,9 +128,6 @@ def get_wan_default_params( else: params = dict(_WAN21_720P_PARAMS) - if include_i2v: - params["image_cond_strength"] = 1.0 - return params diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py index 1cda852093f1..73036ee8ad4a 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py @@ -388,6 +388,7 @@ def infer(self, req): def forward( self, prompt: Union[str, List[str]], + seed: int, negative_prompt: Optional[str] = None, height: int = 720, width: int = 1280, @@ -396,7 +397,6 @@ def forward( guidance_scale: Optional[float] = None, guidance_scale_2: Optional[float] = None, boundary_ratio: Optional[float] = None, - seed: int = 42, max_sequence_length: int = 512, image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None, ): diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py index 2ff09e153a03..bb27e8b017d0 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py @@ -370,7 +370,6 @@ def default_generation_params(self): is_wan22_14b=self.is_wan22_14b, name_or_path=getattr(self.config, "_name_or_path", ""), num_heads=getattr(self.config, "num_attention_heads", 40), - include_i2v=True, ) @property @@ -417,6 +416,7 @@ def forward( self, image: Union[PIL.Image.Image, torch.Tensor, str], prompt: Union[str, List[str]], + seed: int, negative_prompt: Optional[str] = None, height: int = 480, width: int = 832, @@ -425,7 +425,6 @@ def forward( guidance_scale: Optional[float] = None, guidance_scale_2: Optional[float] = None, boundary_ratio: Optional[float] = None, - seed: int = 42, max_sequence_length: int = 512, last_image: Optional[Union[PIL.Image.Image, torch.Tensor, str]] = None, ): diff --git a/tensorrt_llm/_torch/visual_gen/output.py b/tensorrt_llm/_torch/visual_gen/output.py index 3ab3686b3aec..62e8dea63a86 100644 --- a/tensorrt_llm/_torch/visual_gen/output.py +++ b/tensorrt_llm/_torch/visual_gen/output.py @@ -162,7 +162,10 @@ def to_visual_gen_output(resp: "DiffusionResponse") -> "VisualGenOutput": from tensorrt_llm.visual_gen.output import VisualGenMetrics, VisualGenOutput if resp.error_msg is not None: - return VisualGenOutput(request_id=resp.request_id, error=resp.error_msg) + return VisualGenOutput( + request_id=resp.request_id, + error=resp.error_msg, + ) out = resp.output metrics = VisualGenMetrics( generation=resp.generation, @@ -197,7 +200,10 @@ def split_visual_gen_output(resp: "DiffusionResponse", batch_size: int) -> List[ if resp.error_msg is not None: return [ - VisualGenOutput(request_id=resp.request_id, error=resp.error_msg) + VisualGenOutput( + request_id=resp.request_id, + error=resp.error_msg, + ) for _ in range(batch_size) ] out = resp.output diff --git a/tensorrt_llm/media/encoding.py b/tensorrt_llm/media/encoding.py index b0e240484625..8a2c4c57b3a2 100644 --- a/tensorrt_llm/media/encoding.py +++ b/tensorrt_llm/media/encoding.py @@ -472,8 +472,10 @@ def save_image( """Encode and save an image tensor to disk. Args: - image: Image as ``torch.Tensor`` ``(H, W, C)`` or ``(B, H, W, C)``, - dtype ``uint8``. If batched, the first image is saved. + image: Image as ``torch.Tensor`` ``(H, W, C)`` or ``(B, H, W, C)`` + with ``B == 1``, dtype ``uint8``. Batched inputs with + ``B > 1`` are rejected — pass a single tensor or use + :func:`save_images` with one path per batch item. output_path: Output file path (``str`` or :class:`pathlib.Path`). format: Image format (``'png'``/``'jpg'``/``'webp'``). If ``None``, inferred from the path extension; defaults to PNG when unknown. @@ -481,11 +483,20 @@ def save_image( Returns: Path string where the image was actually saved. + + Raises: + ValueError: When ``image`` is a 4-D tensor with ``B > 1``. """ if isinstance(output_path, Path): output_path = str(output_path) if hasattr(image, "dim") and image.dim() == 4: + if image.shape[0] > 1: + raise ValueError( + f"save_image received a batched tensor of size {image.shape[0]}; " + "pass a single (H, W, C) tensor, or use save_images(images, paths) " + "with one path per batch item." + ) image = image[0] output_dir = os.path.dirname(output_path) if output_dir: @@ -530,8 +541,10 @@ def save_video( """Encode and save a video tensor (with optional audio) to disk. Args: - video: Video as ``torch.Tensor`` ``(T, H, W, C)`` or ``(B, T, H, W, C)``, - dtype ``uint8``. If batched, the first video is saved. + video: Video as ``torch.Tensor`` ``(T, H, W, C)`` or ``(B, T, H, W, C)`` + with ``B == 1``, dtype ``uint8``. Batched inputs with + ``B > 1`` are rejected — pass a single tensor or use + :func:`save_videos` with one path per batch item. output_path: Output file path (``str`` or :class:`pathlib.Path`). audio: Optional audio tensor; ignored by the pure-Python AVI fallback. frame_rate: Frames per second. @@ -542,10 +555,19 @@ def save_video( Returns: Path string where the video was actually saved. + + Raises: + ValueError: When ``video`` is a 5-D tensor with ``B > 1``. """ if isinstance(output_path, Path): output_path = str(output_path) if hasattr(video, "dim") and video.dim() == 5: + if video.shape[0] > 1: + raise ValueError( + f"save_video received a batched tensor of size {video.shape[0]}; " + "pass a single (T, H, W, C) tensor, or use save_videos(videos, paths) " + "with one path per batch item." + ) video = video[0] output_dir = os.path.dirname(output_path) diff --git a/tensorrt_llm/media/tensor_payload.py b/tensorrt_llm/media/tensor_payload.py new file mode 100644 index 000000000000..b61580d9f439 --- /dev/null +++ b/tensorrt_llm/media/tensor_payload.py @@ -0,0 +1,246 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Tensor-format serializers for :class:`VisualGenOutput`. + +Two payload formats are supported: + +- ``"safetensors"``: writes a single file with named tensors + (``image``/``video``/``audio``). Scalar metadata (``frame_rate``, + ``audio_sample_rate``) is stored two ways: as a 0-d tensor under + the same key (so ``safetensors.torch.load(bytes)`` returns it + alongside the media tensors — consumers call ``.item()`` to + unbox) and as a stringified value in the file header (preserved + for callers using ``safe_open(...).metadata()``). No pickle on + load. +- ``"pt"``: writes a single file via :func:`torch.save` with the + same tensor keys plus scalar metadata as native Python values. + Clients should load with ``torch.load(buf, weights_only=True)`` + on PyTorch 2.4+. + +Both serializers share the same logical payload shape so the +serve layer can pick the bytes-based path (``b64_json`` transport) +or the file-based path (``url`` transport) without having to +reconstruct the payload twice. +""" + +from __future__ import annotations + +import io +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + +import torch + +if TYPE_CHECKING: + from tensorrt_llm.visual_gen.output import VisualGenOutput + + +# Tokens recognized by the tensor-payload path. Anything outside this +# set is treated as a media-encoder request by :meth:`VisualGenOutput.save`. +TENSOR_FORMATS = frozenset({"safetensors", "pt"}) + + +def is_tensor_format(fmt: Optional[str]) -> bool: + """Return True when *fmt* names a tensor payload (safetensors / pt).""" + return fmt in TENSOR_FORMATS + + +# Ranks at which each modality is batched. An image tensor with the +# canonical ``(H, W, C)`` shape is unbatched at rank 3 and batched at +# rank 4; video is unbatched at rank 4 ``(T, H, W, C)`` and batched at +# rank 5; audio is unbatched at rank 2 ``(channels, T_audio)`` and +# batched at rank 3. The serializer uses these to decide whether a +# media tensor has a true batch axis to slice along. +_BATCHED_RANKS: Dict[str, int] = { + "image": 4, + "video": 5, + "audio": 3, +} + + +def _modalities(output: "VisualGenOutput") -> Tuple[Tuple[str, Optional[torch.Tensor]], ...]: + return ( + ("image", output.image), + ("video", output.video), + ("audio", output.audio), + ) + + +def infer_batch_size(output: "VisualGenOutput") -> int: + """Return the leading batch dimension across the populated media tensors. + + Image is batched only at rank 4, video at rank 5, audio at rank 3. + An unbatched media tensor reports a batch size of 1 so list-path + callers can still ask for ``[0]`` and get a single-item payload. + Raises :class:`ValueError` when *output* carries no media tensor. + """ + sizes = set() + have_media = False + for name, tensor in _modalities(output): + if tensor is None: + continue + have_media = True + if tensor.dim() == _BATCHED_RANKS[name]: + sizes.add(int(tensor.shape[0])) + else: + sizes.add(1) + if not have_media: + raise ValueError( + f"Cannot infer batch size: request {output.request_id} carries no media tensor." + ) + if len(sizes) > 1: + raise ValueError( + f"Inconsistent batch sizes across modalities: {sorted(sizes)}. " + "All populated media tensors must agree on the leading batch axis." + ) + return next(iter(sizes)) + + +def _collect_tensors_and_metadata( + output: "VisualGenOutput", + batch_index: Optional[int], + *, + frame_rate_override: Optional[float] = None, + audio_sample_rate_override: Optional[int] = None, +) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Build the ``{name: tensor}`` and ``{name: scalar}`` views over *output*. + + When ``batch_index`` is provided, slice each populated media tensor + along its leading batch dimension *only if* the tensor's rank says + it actually has a batch axis (see :data:`_BATCHED_RANKS`). Tensors + that are already unbatched are passed through unchanged so callers + don't accidentally slice height (image) or frame (video) axes. + + When ``batch_index`` is ``None`` the tensors are written as-is. + + ``frame_rate_override`` / ``audio_sample_rate_override`` take + precedence over the corresponding fields on *output*. This mirrors + the video-encoder path's ``frame_rate``/``audio_sample_rate`` + keyword overrides on :meth:`VisualGenOutput.save` so callers that + pass overrides (because the output fields are ``None`` or need to + be replaced) get those values in the serialized payload's + metadata. + """ + tensors: Dict[str, torch.Tensor] = {} + metadata: Dict[str, Any] = {} + + for name, src in _modalities(output): + if src is None: + continue + if batch_index is not None and src.dim() == _BATCHED_RANKS[name]: + sliced = src[batch_index] + else: + sliced = src + tensors[name] = sliced.contiguous().cpu() + + frame_rate = frame_rate_override if frame_rate_override is not None else output.frame_rate + audio_sample_rate = ( + audio_sample_rate_override + if audio_sample_rate_override is not None + else output.audio_sample_rate + ) + if frame_rate is not None: + metadata["frame_rate"] = float(frame_rate) + if audio_sample_rate is not None: + metadata["audio_sample_rate"] = int(audio_sample_rate) + + return tensors, metadata + + +def serialize_visual_gen_output( + output: "VisualGenOutput", + fmt: str, + *, + batch_index: Optional[int] = None, + frame_rate: Optional[float] = None, + audio_sample_rate: Optional[int] = None, +) -> bytes: + """Serialize *output* to in-memory bytes using *fmt*. + + Args: + output: The :class:`VisualGenOutput` to serialize. Must have at + least one populated media tensor. + fmt: ``"safetensors"`` or ``"pt"``. + batch_index: When set, slice each populated tensor along its + leading batch dimension before serialization so the result + corresponds to a single batch item. + frame_rate: Override the ``frame_rate`` written into the + payload metadata. Falls back to ``output.frame_rate``. + audio_sample_rate: Override the ``audio_sample_rate`` written + into the payload metadata. Falls back to + ``output.audio_sample_rate``. + + Returns: + Serialized bytes ready for ``b64_json`` transport or for writing + to disk via :func:`save_visual_gen_output_payload`. + + Raises: + ValueError: When *fmt* is not a supported tensor token or + *output* carries no media tensor. + """ + if not is_tensor_format(fmt): + raise ValueError( + f"Unsupported tensor format: {fmt!r}. Use one of {sorted(TENSOR_FORMATS)}." + ) + + tensors, metadata = _collect_tensors_and_metadata( + output, + batch_index, + frame_rate_override=frame_rate, + audio_sample_rate_override=audio_sample_rate, + ) + if not tensors: + raise ValueError( + f"Cannot serialize output: request {output.request_id} carries no media tensor." + ) + + if fmt == "safetensors": + from safetensors.torch import save as safetensors_save + + # Store each scalar twice: as a 0-d tensor (survives the canonical + # ``safetensors.torch.load(bytes)`` path so consumers can read + # ``loaded["frame_rate"].item()`` directly) and as a string in the + # file header (preserved for callers that already use + # ``safe_open(...).metadata()``). The two views always agree. + scalar_tensors = {k: torch.as_tensor(v) for k, v in metadata.items()} + return safetensors_save( + {**tensors, **scalar_tensors}, + metadata={k: str(v) for k, v in metadata.items()}, + ) + + payload: Dict[str, Any] = {**tensors, **metadata} + buf = io.BytesIO() + torch.save(payload, buf) + return buf.getvalue() + + +def save_visual_gen_output_payload( + output: "VisualGenOutput", + path: Union[str, Path], + fmt: str, + *, + batch_index: Optional[int] = None, + frame_rate: Optional[float] = None, + audio_sample_rate: Optional[int] = None, +) -> Path: + """Write the tensor payload for *output* to *path* using *fmt*. + + The path's suffix is normalized to ``.safetensors`` or ``.pt`` when + missing so on-disk artifacts are always identifiable by extension. + ``frame_rate`` and ``audio_sample_rate`` override the + corresponding fields on *output* in the serialized payload's + metadata, matching :meth:`VisualGenOutput.save`'s encoder path. + """ + target = Path(path) + if target.suffix == "": + target = target.with_suffix(f".{fmt}") + target.parent.mkdir(parents=True, exist_ok=True) + data = serialize_visual_gen_output( + output, + fmt, + batch_index=batch_index, + frame_rate=frame_rate, + audio_sample_rate=audio_sample_rate, + ) + target.write_bytes(data) + return target diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index 59cc2f2d7295..71595bf5582f 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -2,7 +2,6 @@ # https://github.com/vllm-project/vllm/blob/4db5176d9758b720b05460c50ace3c01026eb158/vllm/entrypoints/openai/protocol.py import base64 import math -import re import time import uuid from typing import Any, Dict, List, Literal, Optional, Union @@ -1350,57 +1349,79 @@ def to_llm_disaggregated_params( class ImageGenerationRequest(OpenAIBaseModel): """OpenAI-compatible image generation request. - Follows the OpenAI Images API specification: - https://platform.openai.com/docs/api-reference/images/create + Universal per-request fields map 1:1 to :class:`VisualGenParams`. + Model-specific knobs (``stg_scale``, ``guidance_rescale``, …) + travel through ``extra_params``; the executor validates each + key against the loaded pipeline's + ``extra_param_specs``. Unknown top-level fields are rejected + with HTTP 422 via the inherited ``extra="forbid"`` policy. """ + + # Prompt + transport (OpenAI-standard, always honored) prompt: str - model: Optional[str] = None - n: int = Field(default=1, ge=1, le=10) - output_format: Literal["png", "webp", "jpeg"] = "png" - size: Optional[str] = Field( - default="auto", - description=( - "The size of the generated images. Must be in 'WxH' format like " - "1024x1024, 1536x1024 (landscape), 1024x1536 (portrait), etc. " - "Use 'auto' for model default size.")) - quality: Literal["standard", "hd"] = "standard" response_format: Literal["url", "b64_json"] = "url" - style: Optional[Literal["vivid", "natural"]] = "vivid" - user: Optional[str] = None + format: Literal["png", "webp", "jpeg", "safetensors", "pt"] = Field( + default="png", + description=( + "Generation content encoding format. Image encoders write " + "``png``/``webp``/``jpeg``; tensor encoders write " + "``safetensors``/``pt`` for programmatic post-processing."), + ) + seed: Optional[int] = Field(default=None, + ge=0, + description="Random seed for reproducibility.") - # Extended parameters for diffusion control - num_inference_steps: Optional[int] = Field( + # Resolution. ``size`` is OpenAI-shaped; ``width`` + ``height`` are an + # equivalent structured alternative. Exactly one of width/height is + # rejected by the paired validator below. Numeric fields use + # ``gt=0`` as a safety net so zero / negative inputs are rejected + # with HTTP 422 before reaching the pipeline. + size: Optional[str] = Field(default=None, pattern=r"^(\d+x\d+|auto)$") + width: Optional[int] = Field(default=None, gt=0) + height: Optional[int] = Field(default=None, gt=0) + + # TRT-LLM-supported per-request params (1:1 with VisualGenParams fields) + num_inference_steps: Optional[int] = Field(default=None, gt=0) + guidance_scale: Optional[float] = Field(default=None, gt=0) + max_sequence_length: Optional[int] = Field(default=None, gt=0) + negative_prompt: Optional[str] = None + n: Optional[int] = Field( default=None, - description= - "Number of denoising steps. More steps = higher quality but slower.") - guidance_scale: Optional[float] = Field( - default=None, - description= - "Classifier-free guidance scale. Higher values follow prompt more closely." + gt=0, + le=10, + description=("Number of images to generate. Capped at 10 to match the " + "OpenAI images API and to bound GPU memory / disk usage."), ) - guidance_rescale: Optional[float] = Field( - default=None, description="Classifier-free guidance rescale.") - negative_prompt: Optional[str] = Field( + + # Model-specific overflow + extra_params: Optional[Dict[str, Any]] = Field( default=None, - description="Text describing what to avoid in the generated image.") - seed: Optional[int] = Field(default=None, - description="Random seed for reproducibility.") + description=( + "Model-specific parameters forwarded to the underlying pipeline. " + "See per-model docs for accepted keys."), + ) - @field_validator("size") - @classmethod - def validate_size(cls, v): - """Validate size format is 'WxH' or 'auto'.""" - if v is None or v == "auto": - return v - if not isinstance(v, str): - raise ValueError("size must be a string in 'WxH' format or 'auto'") - # Check format: should be like "1024x1024" - import re - if not re.match(r'^\d+x\d+$', v): + # Accepted-but-ignored OpenAI-shaped fields. The conversion no-ops; the + # server logs WARNING when a client sets ``quality`` or ``style``, and + # WARNING-on-mismatch for ``model``. Kept in the schema so OpenAI-SDK + # clients don't trip ``extra="forbid"``. + model: Optional[str] = None + quality: Optional[Literal["standard", "hd"]] = None + style: Optional[Literal["vivid", "natural"]] = None + user: Optional[str] = None + + @model_validator(mode="after") + def _check_paired_dimensions(self): + """Reject sending exactly one of ``width`` / ``height``. + + Either both are sent (structured resolution wins over ``size``) + or neither is sent (``size`` or pipeline default applies). + """ + if (self.width is None) != (self.height is None): raise ValueError( - f"Invalid size format '{v}'. Must be in 'WxH' format " - "(e.g., '1024x1024', '1536x1024') or 'auto'.") - return v + "width and height must be sent together; got width=" + f"{self.width!r}, height={self.height!r}") + return self class ImageObject(OpenAIBaseModel): @@ -1411,119 +1432,99 @@ class ImageObject(OpenAIBaseModel): class ImageGenerationResponse(OpenAIBaseModel): - """Response from image generation endpoint.""" + """Response from image generation endpoint. + + ``output_format`` reports the encoding actually applied to the + returned bytes / files so clients can decode or label the payload + correctly. Image encoders are ``"png"``/``"webp"``/``"jpeg"``; + tensor formats are ``"safetensors"``/``"pt"``. + """ + created: int = Field(default_factory=lambda: int(time.time())) data: List[ImageObject] - output_format: Literal["png", "webp", "jpeg"] = "png" + output_format: Literal["png", "webp", "jpeg", "safetensors", "pt"] = "png" quality: Literal["low", "medium", "high"] = "medium" size: Optional[str] = None -class ImageEditRequest(OpenAIBaseModel): - """Request for image editing endpoint. +class VideoGenerationRequest(OpenAIBaseModel): + """Video generation request (extended API). - Follows the OpenAI Images API specification: - https://platform.openai.com/docs/api-reference/images/createEdit + Universal per-request fields map 1:1 to :class:`VisualGenParams`. + Model-specific knobs travel through ``extra_params``. Unknown + top-level fields are rejected with HTTP 422 via the inherited + ``extra="forbid"`` policy. """ - image: Union[List[str], str] = Field( - description="Base64-encoded source image(s) to edit") - prompt: str = Field(description="Text description of desired edits") - model: Optional[str] = None - mask: Optional[str] = Field( - default=None, - description= - "Base64-encoded mask image (optional, black areas will be edited)") - n: int = Field(default=1, ge=1, le=10) - size: Optional[str] = Field( + + # Prompt + transport + prompt: str + response_format: Literal["url", "b64_json"] = "url" + format: Literal["mp4", "avi", "auto", "safetensors", "pt"] = Field( default="auto", description=( - "The size of the edited images. Must be in 'WxH' format like " - "1024x1024, 1536x1024 (landscape), 1024x1536 (portrait), etc. " - "Use 'auto' to match source image size.")) - response_format: Literal["url", "b64_json"] = "url" - user: Optional[str] = None - - # Extended parameters for diffusion control - num_inference_steps: Optional[int] = Field( - default=None, description="Number of denoising steps.") - guidance_scale: Optional[float] = Field( - default=None, description="Classifier-free guidance scale.") - guidance_rescale: Optional[float] = Field( - default=None, description="Classifier-free guidance rescale.") - negative_prompt: Optional[str] = Field( - default=None, - description="Text describing what to avoid in the edited image.") + "Generation content encoding format. Video encoders write " + "``mp4``/``avi``/``auto``; tensor encoders write " + "``safetensors``/``pt`` and carry video, audio, and scalar " + "metadata (frame rate, audio sample rate) in one payload."), + ) seed: Optional[int] = Field(default=None, + ge=0, description="Random seed for reproducibility.") - - @field_validator("size") - @classmethod - def validate_size(cls, v): - """Validate size format is 'WxH' or 'auto'.""" - if v != "auto" and not re.match(r"^\d+x\d+$", v): - raise ValueError( - "Size must be 'auto' or in 'WxH' format (e.g., '1024x1024')") - return v - - -class VideoGenerationRequest(OpenAIBaseModel): - """Video generation request (extended API). - - This is an extension to the OpenAI API for video generation support. - """ - prompt: str input_reference: Optional[Union[str, UploadFile]] = Field( default=None, - description="Optional image reference that guides generation.") - model: Optional[str] = None - size: Optional[str] = Field( - default="auto", - description= - ("The size of the generated video frames. Must be in 'WxH' format like " - "512x512, 1024x576 (landscape), 576x1024 (portrait), etc. " - "Use 'auto' for model default size.")) - seconds: float = Field(default=2.0, - ge=1.0, - le=16.0, - description="Video duration in seconds.") - - # Extended parameters for diffusion control - n: int = Field(default=1, ge=1, le=4) - fps: int = Field(default=24, ge=8, le=60, description="Frames per second.") - num_inference_steps: Optional[int] = Field( - default=None, description="Number of denoising steps.") - guidance_scale: Optional[float] = Field( - default=None, description="Classifier-free guidance scale.") - guidance_rescale: Optional[float] = Field( - default=None, description="Classifier-free guidance rescale.") - negative_prompt: Optional[str] = Field( + description="Optional image reference that guides generation.", + ) + + # Resolution + size: Optional[str] = Field(default=None, pattern=r"^(\d+x\d+|auto)$") + width: Optional[int] = Field(default=None, gt=0) + height: Optional[int] = Field(default=None, gt=0) + + # Frame budget. ``num_frames`` is preferred; if absent the engine + # derives it from ``seconds * frame_rate``. ``frame_rate`` is the + # canonical name (matches the Python field); ``fps`` is an alias for + # OpenAI-shape clients via ``populate_by_name=True``. + # All three constrain to strictly positive values so a zero + # ``frame_rate`` (division-by-zero in the AVI fallback) or a + # negative ``num_frames`` are rejected with HTTP 422 before + # reaching the encoder. + # Upper bounds keep request-boundary protection against requests + # that can exhaust GPU memory or pin the server on unbounded work. + # The numbers are generous (a minute of video at 120 fps) so common + # workloads pass; clients that need larger budgets can lift the cap + # at deployment time. + num_frames: Optional[int] = Field(default=None, gt=0, le=7200) + seconds: Optional[float] = Field(default=None, gt=0, le=60.0) + frame_rate: Optional[float] = Field(default=None, + alias="fps", + gt=0, + le=120.0) + + # TRT-LLM-supported per-request params (1:1 with VisualGenParams) + num_inference_steps: Optional[int] = Field(default=None, gt=0) + guidance_scale: Optional[float] = Field(default=None, gt=0) + max_sequence_length: Optional[int] = Field(default=None, gt=0) + negative_prompt: Optional[str] = None + + # Model-specific overflow + extra_params: Optional[Dict[str, Any]] = Field( default=None, - description="Text describing what to avoid in the generated video.") - seed: Optional[int] = Field(default=None, - description="Random seed for reproducibility.") - output_format: Literal["mp4", "avi", "auto"] = Field( - default="auto", description=( - "Video encode format. " - "'mp4' for H.264 encoding (requires ffmpeg installed on server), " - "'avi' for MJPEG encoding (always available, no audio support), " - "'auto' to use best available (H.264 if ffmpeg installed, " - "otherwise MJPEG).")) + "Model-specific parameters forwarded to the underlying pipeline. " + "See per-model docs for accepted keys."), + ) - @field_validator("size") - @classmethod - def validate_size(cls, v): - """Validate size format is 'WxH' or 'auto'.""" - if v is None or v == "auto": - return v - if not isinstance(v, str): - raise ValueError("size must be a string in 'WxH' format or 'auto'") - import re - if not re.match(r'^\d+x\d+$', v): + # Accepted-but-ignored OpenAI-shaped field + model: Optional[str] = None + + @model_validator(mode="after") + def _check_paired_dimensions(self): + """Reject sending exactly one of ``width`` / ``height``.""" + if (self.width is None) != (self.height is None): raise ValueError( - f"Invalid size format '{v}'. Must be in 'WxH' format " - "(e.g., '512x512', '1024x576') or 'auto'.") - return v + "width and height must be sent together; got width=" + f"{self.width!r}, height={self.height!r}") + return self class VideoJob(OpenAIBaseModel): @@ -1552,13 +1553,26 @@ class VideoJob(OpenAIBaseModel): # Video properties duration: Optional[float] = Field(default=None, description="Video duration in seconds") - fps: Optional[int] = Field(default=None, description="Frames per second") + fps: Optional[float] = Field( + default=None, + description=( + "Frames per second. Float to preserve cinematic rates such " + "as 23.976 / 29.97 that some encoders / pipelines use."), + ) size: Optional[str] = Field(default=None, description="Video dimensions in 'WxH' format") output_path: Optional[str] = Field( default=None, description="Actual path where the video file was saved") output_paths: Optional[List[str]] = Field( default=None, description="Paths for all generated videos when n > 1") + response_format: Optional[Literal["url", "b64_json"]] = Field( + default=None, + description=( + "Transport the client requested. ``GET /v1/videos/{id}/content`` " + "honors this: ``b64_json`` returns the encoded payload as a " + "base64 string inside a JSON envelope; ``url`` (or unset) " + "returns the file as a ``FileResponse`` download."), + ) class VideoJobList(OpenAIBaseModel): diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index ce157cb92a42..9ae68d9e63e2 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -20,7 +20,8 @@ import uvicorn from fastapi import Body, FastAPI, Request from fastapi.exceptions import RequestValidationError -from fastapi.responses import JSONResponse, Response, StreamingResponse +from fastapi.responses import (FileResponse, JSONResponse, Response, + StreamingResponse) from pydantic import ValidationError from starlette.routing import Mount from transformers import AutoProcessor @@ -44,6 +45,7 @@ add_thinking_budget_logits_processor from tensorrt_llm.logger import logger from tensorrt_llm.media.encoding import image_to_bytes +from tensorrt_llm.media.tensor_payload import is_tensor_format from tensorrt_llm.metrics.collector import MetricsCollector from tensorrt_llm.sampling_params import GuidedDecodingParams from tensorrt_llm.serve.chat_utils import (load_chat_template, @@ -58,7 +60,7 @@ ChatMessage, CompletionRequest, CompletionResponse, CompletionResponseChoice, - ErrorResponse, ImageEditRequest, + ErrorResponse, ImageGenerationRequest, ImageGenerationResponse, ImageObject, @@ -321,6 +323,12 @@ async def lifespan(app: FastAPI): @self.app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): + if self.server_role is ServerRole.VISUAL_GEN: + return self._create_visual_gen_validation_error_response(exc) + # Non-visual-gen roles keep the shared 400 + ``{"error": ...}`` + # response shape that integration tests (e.g. + # ``test_malformed_json_request``) and existing clients + # expect. if self.metrics_collector: self.metrics_collector.log_request_error(http_code=400) return JSONResponse(status_code=400, content={"error": str(exc)}) @@ -592,6 +600,36 @@ def _create_not_supported_error(self, message: str) -> Response: status_code=HTTPStatus.NOT_IMPLEMENTED, ) + def _create_visual_gen_validation_error_response( + self, exc: RequestValidationError) -> Response: + """Render a ``RequestValidationError`` as the visual-gen 422 envelope. + + The body has the LLM-style ``{message, type, code}`` shape with + HTTP 422. The ``message`` field names the offending field(s) for + each error in ``exc.errors()`` so clients can fix the request + without parsing the full Pydantic payload. + """ + parts: List[str] = [] + for err in exc.errors(): + loc = ".".join( + str(seg) for seg in err.get("loc", ()) if seg != "body") + etype = err.get("type", "") + msg = err.get("msg", "") + if etype == "extra_forbidden": + parts.append( + f"Unknown request field {loc!r}. Pass model-specific " + "parameters via 'extra_params' instead.") + elif loc: + parts.append(f"{loc}: {msg}") + else: + parts.append(msg) + message = "; ".join(parts) if parts else str(exc) + return self.create_error_response( + message=message, + err_type="BadRequestError", + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + ) + def _check_health(self) -> bool: if isinstance(self.generator, LLM): return self.generator._check_health() @@ -749,6 +787,9 @@ def register_visual_gen_routes(self): self.app.add_api_route("/v1/images/edits", self.openai_image_edit, methods=["POST"]) + self.app.add_api_route("/v1/images/{image_id}/content", + self.get_image_content, + methods=["GET"]) # Video generation endpoints (Extended OpenAI API) # Asynchronous video generation (returns immediately with job metadata, OpenAI API) @@ -1599,6 +1640,7 @@ async def generator_wrapper(generator: AsyncIterator[Any]): async def chat_harmony(self, request: ChatCompletionRequest, raw_request: Request) -> Response: """Chat Completion API with harmony format support. + Supports both streaming and non-streaming modes. """ @@ -1944,18 +1986,34 @@ async def openai_image_generation(self, request: ImageGenerationRequest, raw_request: Request) -> Response: """OpenAI-compatible image generation endpoint. - Follows the OpenAI Images API specification for image generation. + Follows the OpenAI Images API specification for image generation, + with ``request.format`` extended to accept tensor payloads + (``"safetensors"``/``"pt"``) alongside the PNG/WebP/JPEG encoders. """ try: image_id = f"image_{uuid.uuid4().hex}" - params = parse_visual_gen_params(request, image_id, self.generator) - logger.info( - f"Generating image: {image_id} with params: {params} and prompt: {request.prompt}" - ) - image_gen_start = time.perf_counter() - output = self.generator.generate(inputs=request.prompt, - params=params) + # Client-side ValueErrors from request translation and + # parameter validation are 400. Serialization failures below + # (server-side: missing media, inconsistent batch) fall + # through to the outer ``except Exception`` → 500 so the + # client doesn't get blamed for a server-internal failure. + try: + params = parse_visual_gen_params(request, image_id, + self.generator) + logger.info( + f"Generating image: {image_id} with params: {params} and prompt: {request.prompt}" + ) + image_gen_start = time.perf_counter() + output = self.generator.generate(inputs=request.prompt, + params=params) + except ValueError as exc: + logger.error(f"Image request error: {exc}") + return self.create_error_response( + message=str(exc), + status_code=HTTPStatus.BAD_REQUEST, + ) + if output.image is None: return self.create_error_response( message="Image generation failed", @@ -1963,29 +2021,78 @@ async def openai_image_generation(self, request: ImageGenerationRequest, status_code=HTTPStatus.INTERNAL_SERVER_ERROR, ) - # Build response - output_images = _normalize_image_output(output.image) - - if request.response_format == "b64_json": - data = [ - ImageObject( - b64_json=base64.b64encode( - image_to_bytes(image)).decode('utf-8'), - revised_prompt=request.prompt, - ) for image in output_images - ] - + if is_tensor_format(request.format): + # Tensor payloads carry every populated modality in a + # single file. Match the image-encoder fan-out by + # emitting one ``ImageObject`` per batch item. + from tensorrt_llm.media.tensor_payload import infer_batch_size + + ext = f".{request.format}" + batch_size = infer_batch_size(output) + if request.response_format == "b64_json": + data = [ + ImageObject( + b64_json=base64.b64encode( + output._save_bytes( + request.format, + batch_index=i)).decode("utf-8"), + revised_prompt=request.prompt, + ) for i in range(batch_size) + ] + else: + paths_in = [ + self.media_storage_path / f"{image_id}_{i}{ext}" + for i in range(batch_size) + ] + output.save(paths_in, format=request.format) + data = [ + ImageObject( + url=self._build_image_content_url( + raw_request, image_id, i), + revised_prompt=request.prompt, + ) for i in range(batch_size) + ] response = ImageGenerationResponse( created=int(time.time()), data=data, + output_format=request.format, + size=f"{params.width}x{params.height}", + ) + else: + output_images = _normalize_image_output(output.image) + # Pillow's format name is the upper-case form of our + # request token. The on-disk extension matches the + # request token directly; ``.jpeg`` is interchangeable + # with ``.jpg`` for Pillow, the OS, and the OpenAI API. + pil_format = request.format.upper() + ext = f".{request.format}" + if request.response_format == "b64_json": + data = [ + ImageObject( + b64_json=base64.b64encode( + image_to_bytes( + image, format=pil_format)).decode("utf-8"), + revised_prompt=request.prompt, + ) for image in output_images + ] + else: + data = [] + for i, image in enumerate(output_images): + path = self.media_storage_path / f"{image_id}_{i}{ext}" + path.write_bytes( + image_to_bytes(image, format=pil_format)) + data.append( + ImageObject( + url=self._build_image_content_url( + raw_request, image_id, i), + revised_prompt=request.prompt, + )) + response = ImageGenerationResponse( + created=int(time.time()), + data=data, + output_format=request.format, size=f"{params.width}x{params.height}", ) - - elif request.response_format == "url": - output.save(self.media_storage_path / f"{image_id}.png") - # TODO: Support URL mode - return self._create_not_supported_error( - "URL mode is not supported for image generation") latency = time.perf_counter() - image_gen_start # seconds metrics = output.metrics @@ -2000,20 +2107,59 @@ async def openai_image_generation(self, request: ImageGenerationRequest, except Exception as e: logger.error(traceback.format_exc()) - return self.create_error_response(str(e)) + return self.create_error_response( + message=str(e), + err_type="InternalServerError", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) - async def openai_image_edit(self, request: ImageEditRequest, - raw_request: Request) -> Response: - """OpenAI-compatible image editing endpoint. + async def get_image_content( + self, + image_id: str, + raw_request: Request, + i: int = 0, + ) -> Response: + """Serve a generated image by ID and batch index. + + ``GET /v1/images/{image_id}/content?i=`` returns + the image file the corresponding ``POST /v1/images/generations`` + wrote into ``media_storage_path``. ``i`` defaults to ``0`` for + single-image requests; URL responses for ``n > 1`` requests + carry ``?i=N`` per item. + """ + for ext in (".png", ".webp", ".jpg", ".jpeg", ".safetensors", ".pt"): + candidate = self.media_storage_path / f"{image_id}_{i}{ext}" + if candidate.exists(): + media_type = ("application/octet-stream" + if ext in (".safetensors", ".pt") else + f"image/{ext.lstrip('.').replace('jpg', 'jpeg')}") + return FileResponse( + str(candidate), + media_type=media_type, + filename=candidate.name, + ) + return self.create_error_response( + message=f"Image {image_id!r} (batch index {i}) not found", + err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND, + ) + + @staticmethod + def _build_image_content_url(raw_request: Request, image_id: str, + i: int) -> str: + """Return a fetchable HTTP URL for a generated image item.""" + base = str(raw_request.base_url).rstrip("/") + return f"{base}/v1/images/{image_id}/content?i={i}" - Follows the OpenAI Images API specification for image editing. - Creates an edited or extended image given an original image and a prompt. + async def openai_image_edit(self, raw_request: Request) -> Response: + """OpenAI-compatible image editing endpoint — returns HTTP 501. No in-tree pipeline implements image editing today: Flux/Flux2 are text-to-image only and ignore ``params.image``; Wan and LTX-2 produce - video, not edited images. Return 501 here so callers get an honest - NotImplemented signal instead of a 500 from a downstream None check. - Re-enable the full handler when an edit-capable pipeline lands. + video, not edited images. The route is registered so callers get an + honest NotImplemented signal instead of a 404. The request body is + not parsed because no schema is committed for this endpoint yet — + bring a typed request model back when an edit-capable pipeline lands. """ return self._create_not_supported_error( "Image editing is not supported by any in-tree pipeline yet.") diff --git a/tensorrt_llm/serve/openai_video_routes.py b/tensorrt_llm/serve/openai_video_routes.py index 4f739201d276..9a2543bbd084 100644 --- a/tensorrt_llm/serve/openai_video_routes.py +++ b/tensorrt_llm/serve/openai_video_routes.py @@ -11,6 +11,8 @@ """ import asyncio +import base64 +import json import os import time import traceback @@ -20,15 +22,63 @@ from fastapi import Request from fastapi.responses import FileResponse, JSONResponse, Response +from pydantic import ValidationError from tensorrt_llm.logger import logger from tensorrt_llm.media.encoding import resolve_video_format +from tensorrt_llm.media.tensor_payload import is_tensor_format from tensorrt_llm.serve.openai_protocol import VideoGenerationRequest, VideoJob, VideoJobList from tensorrt_llm.serve.visual_gen_metrics import build_visual_gen_timing_headers from tensorrt_llm.serve.visual_gen_utils import VIDEO_STORE, parse_visual_gen_params from tensorrt_llm.visual_gen.params import VisualGenParams +def _video_content_type(suffix: str) -> str: + """Map a video file suffix to its HTTP ``Content-Type``.""" + if suffix == ".mp4": + return "video/mp4" + if suffix == ".avi": + return "video/x-msvideo" + return "application/octet-stream" + + +# File suffixes the GET /v1/videos/{id}/content and DELETE +# /v1/videos/{id} routes try when the stored output_path is missing. +_KNOWN_VIDEO_OUTPUT_SUFFIXES = (".mp4", ".avi", ".safetensors", ".pt") + + +def _preflight_encoder_format(fmt): + """Pre-flight an encoder format string before any GPU work. + + Returns the resolved encoder format token, or ``None`` for tensor + formats (which carry no encoder dependency). Raises ``ValueError`` + for both unsupported format strings and the missing-ffmpeg case on + ``format='mp4'`` so the route's existing 400 handler renders the + message; without this normalization the missing-ffmpeg + ``RuntimeError`` would fall through to the generic 500 handler. + """ + if is_tensor_format(fmt): + return None + try: + return resolve_video_format(fmt)[0] + except RuntimeError as exc: + raise ValueError(str(exc)) from exc + + +def _b64_json_video_response(video_id: str, fmt: str, path: Path) -> JSONResponse: + """Build the OpenAI-style ``{id, format, b64_json}`` envelope. + + Reads bytes from a saved video file on disk and base64-inlines them. + """ + return JSONResponse( + content={ + "id": video_id, + "format": fmt, + "b64_json": base64.b64encode(path.read_bytes()).decode("utf-8"), + } + ) + + class _VideoRoutesMixin: """Mixin providing the eight video-generation endpoints. @@ -48,22 +98,33 @@ async def openai_video_generation_sync(self, raw_request: Request) -> Response: - Multipart: Send form fields + optional input_reference file """ try: - # Parse request based on content-type - request = await self._parse_video_generation_request(raw_request) - - # Resolve the video encode format (mp4/avi/auto) - resolved_fmt, _ = resolve_video_format(request.output_format) - - video_id = f"video_{uuid.uuid4().hex}" - params = parse_visual_gen_params( - request, video_id, self.generator, media_storage_path=str(self.media_storage_path) - ) - logger.info( - f"Generating video: {video_id} with params: {params} and prompt: {request.prompt}" - ) + # Client-side ValueErrors from content-type parsing, request + # translation, encoder-format preflight, parameter validation, + # and the synchronous engine call return 400. Serialization / + # encoder failures further down (server-side) fall through to + # the outer ``except Exception`` → 500. + try: + # Parse request based on content-type + request = await self._parse_video_generation_request(raw_request) + video_id = f"video_{uuid.uuid4().hex}" + params = parse_visual_gen_params( + request, + video_id, + self.generator, + media_storage_path=str(self.media_storage_path), + ) + resolved_encoder_fmt = _preflight_encoder_format(request.format) + logger.info( + f"Generating video: {video_id} with params: {params} and prompt: {request.prompt}" + ) + sync_video_start = time.perf_counter() + output = self.generator.generate(inputs=request.prompt, params=params) + except ValidationError as exc: + return self._render_pydantic_validation_error(exc) + except ValueError as exc: + logger.error(f"Video request error: {exc}") + return self.create_error_response(str(exc), status_code=HTTPStatus.BAD_REQUEST) - sync_video_start = time.perf_counter() - output = self.generator.generate(inputs=request.prompt, params=params) if output.video is None: return self.create_error_response( message="Video generation failed", @@ -71,13 +132,38 @@ async def openai_video_generation_sync(self, raw_request: Request) -> Response: status_code=HTTPStatus.INTERNAL_SERVER_ERROR, ) - # Save all generated videos (batch-aware). + if is_tensor_format(request.format): + ext = f".{request.format}" + media_type = "application/octet-stream" + # Match the encoder-format path: persist one file per batch + # item, ship the first as the route's primary download + # (OpenAI sync video API does not define a multi-file + # response yet — TRTLLM-11579). + batch_size = output.video.shape[0] if output.video.dim() == 5 else 1 + tensor_paths = [ + self.media_storage_path / f"{video_id}_{i}{ext}" for i in range(batch_size) + ] + saved_paths = output.save(tensor_paths, format=request.format) + target = saved_paths[0] + latency = time.perf_counter() - sync_video_start + logger.info( + f"Video {video_id} serialized as tensor: latency={latency:.3f}s " + f"generation={getattr(output.metrics, 'generation', 0.0):.3f}s" + ) + if request.response_format == "b64_json": + return _b64_json_video_response(video_id, request.format, target) + return FileResponse(str(target), media_type=media_type, filename=target.name) + + # Encoder formats: one file per item; ship the first item as + # the route's primary download (OpenAI sync video API does + # not define a multi-file response yet — TRTLLM-11579). + resolved_fmt = resolved_encoder_fmt batch_size = output.video.shape[0] if output.video.dim() == 5 else 1 paths_in = [self.media_storage_path / f"{video_id}_{i}" for i in range(batch_size)] saved_paths = output.save( paths_in, format=resolved_fmt, - frame_rate=output.frame_rate or request.fps or params.frame_rate, + frame_rate=output.frame_rate or request.frame_rate or params.frame_rate, ) latency = time.perf_counter() - sync_video_start # seconds metrics = output.metrics @@ -94,19 +180,19 @@ async def openai_video_generation_sync(self, raw_request: Request) -> Response: # multi-file response, so we return only the first video as a file # download while persisting all of them to disk. actual_path = saved_paths[0] - actual_output_path = str(actual_path) - media_type = "video/mp4" if actual_path.suffix == ".mp4" else "video/x-msvideo" - + if request.response_format == "b64_json": + return _b64_json_video_response( + video_id, actual_path.suffix.lstrip("."), actual_path + ) return FileResponse( - actual_output_path, - media_type=media_type, + str(actual_path), + media_type=_video_content_type(actual_path.suffix), filename=actual_path.name, headers=headers, ) - except ValueError as e: - logger.error(f"Request parsing error: {e}") - return self.create_error_response(str(e)) + except ValidationError as exc: + return self._render_pydantic_validation_error(exc) except Exception as e: logger.error(traceback.format_exc()) return self.create_error_response( @@ -119,65 +205,83 @@ async def _parse_video_generation_request( self, raw_request: Request, ) -> VideoGenerationRequest: - """Parse video generation request from either JSON or multipart/form-data. - - Supports both: - - application/json: Standard JSON request with VideoGenerationRequest model - - multipart/form-data: Form fields + file upload for input_reference + """Parse a video generation request from JSON or multipart form data. + + Both content types funnel through ``VideoGenerationRequest`` for + final validation so the wire contract is identical on either + path: unknown top-level fields are rejected by + ``extra="forbid"``, the paired ``width``/``height`` validator + runs, and the ``fps`` alias is honored via the model's + ``populate_by_name=True`` config. + + Multipart payloads come in as strings; Pydantic coerces them to + the declared field types. ``extra_params`` accepts a + JSON-encoded object as its string form so multipart callers can + pass model-specific knobs. """ content_type = raw_request.headers.get("content-type", "") if "application/json" in content_type: - # Parse as JSON using Pydantic model body = await raw_request.json() return VideoGenerationRequest(**body) if "multipart/form-data" in content_type: - # Parse multipart/form-data manually form = await raw_request.form() - - # Extract all fields and convert to proper types data = {} + for key in form: + value = form[key] + if hasattr(value, "file"): + # Uploaded file (``input_reference``) — pass through + # so the conversion layer reads ``.file``. + data[key] = value + continue + if key == "extra_params": + if value == "": + continue + try: + data[key] = json.loads(value) + except json.JSONDecodeError as exc: + raise ValueError( + f"'extra_params' must be a JSON object string; {exc}" + ) from exc + continue + if value == "": + continue + data[key] = value + return VideoGenerationRequest(**data) - # Required field - if "prompt" in form: - data["prompt"] = form["prompt"] - else: - raise ValueError("'prompt' is required") - - # Optional string fields - for field in ["model", "size", "negative_prompt", "output_format"]: - if field in form and form[field]: - data[field] = form[field] - - # Optional numeric fields - if "seconds" in form and form["seconds"]: - data["seconds"] = float(form["seconds"]) - if "fps" in form and form["fps"]: - data["fps"] = int(form["fps"]) - if "n" in form and form["n"]: - data["n"] = int(form["n"]) - if "num_inference_steps" in form and form["num_inference_steps"]: - data["num_inference_steps"] = int(form["num_inference_steps"]) - if "guidance_scale" in form and form["guidance_scale"]: - data["guidance_scale"] = float(form["guidance_scale"]) - if "guidance_rescale" in form and form["guidance_rescale"]: - data["guidance_rescale"] = float(form["guidance_rescale"]) - if "seed" in form and form["seed"]: - data["seed"] = int(form["seed"]) - - # Handle file upload for input_reference - if "input_reference" in form: - input_ref = form["input_reference"] - if hasattr(input_ref, "file"): # It's an UploadFile - data["input_reference"] = input_ref + raise ValueError( + f"Unsupported content-type: {content_type}. Use 'application/json' or 'multipart/form-data'" + ) - return VideoGenerationRequest(**data) + def _render_pydantic_validation_error(self, exc: ValidationError) -> Response: + """Render a multipart Pydantic ``ValidationError`` as the LLM envelope. - else: - raise ValueError( - f"Unsupported content-type: {content_type}. Use 'application/json' or 'multipart/form-data'" - ) + The visual-gen-scoped 422 envelope is the same shape JSON requests + get from the FastAPI ``RequestValidationError`` handler, so JSON + and multipart clients see indistinguishable bodies on bad + payloads. + """ + parts: list[str] = [] + for err in exc.errors(): + loc = ".".join(str(seg) for seg in err.get("loc", ()) if seg != "body") + etype = err.get("type", "") + msg = err.get("msg", "") + if etype == "extra_forbidden": + parts.append( + f"Unknown request field {loc!r}. Pass model-specific " + "parameters via 'extra_params' instead." + ) + elif loc: + parts.append(f"{loc}: {msg}") + else: + parts.append(msg) + message = "; ".join(parts) if parts else str(exc) + return self.create_error_response( + message=message, + err_type="BadRequestError", + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + ) async def openai_video_generation_async( self, @@ -201,6 +305,19 @@ async def openai_video_generation_async( params = parse_visual_gen_params( request, video_id, self.generator, media_storage_path=str(self.media_storage_path) ) + # Synchronously validate the resolved params against the + # loaded pipeline's extra-param specs / declared defaults + # so unknown ``extra_params`` keys and similar engine-side + # rejections surface as HTTP 400 here instead of becoming + # a queued job whose background task later fails. + from tensorrt_llm.visual_gen.params import validate_visual_gen_params + + validate_visual_gen_params( + params, + declared_defaults=self.generator.executor.default_generation_params, + extra_param_specs=self.generator.executor.extra_param_specs, + ) + _preflight_encoder_format(request.format) logger.info( f"Generating video: {video_id} with params: {params} and prompt: {request.prompt}" ) @@ -214,8 +331,9 @@ async def openai_video_generation_async( prompt=request.prompt, status="queued", duration=request.seconds, - fps=request.fps, + fps=params.frame_rate, size=f"{params.width}x{params.height}", + response_format=request.response_format, ) await VIDEO_STORE.upsert(video_id, video_job) @@ -232,9 +350,11 @@ async def openai_video_generation_async( return JSONResponse(content=video_job.model_dump(), status_code=202) + except ValidationError as exc: + return self._render_pydantic_validation_error(exc) except ValueError as e: - logger.error(f"Request parsing error: {e}") - return self.create_error_response(str(e)) + logger.error(f"Async video request error: {e}") + return self.create_error_response(str(e), status_code=HTTPStatus.BAD_REQUEST) except Exception as e: logger.error(traceback.format_exc()) return self.create_error_response( @@ -251,9 +371,6 @@ async def _generate_video_background( ): """Background task to generate video and save to storage.""" try: - # Resolve the video encode format (mp4/avi/auto) - resolved_fmt, _ = resolve_video_format(request.output_format) - background_start = time.perf_counter() future = self.generator.generate_async(inputs=request.prompt, params=params) output = await future @@ -268,14 +385,25 @@ async def _generate_video_background( await VIDEO_STORE.upsert(video_id, job) return - # Save all generated videos (batch-aware). - batch_size = output.video.shape[0] if output.video.dim() == 5 else 1 - paths_in = [self.media_storage_path / f"{video_id}_{i}" for i in range(batch_size)] - saved_paths = output.save( - paths_in, - format=resolved_fmt, - frame_rate=output.frame_rate or request.fps or params.frame_rate, - ) + if is_tensor_format(request.format): + # One tensor file per batch item, mirroring the encoder + # path; the async job records all paths on + # ``output_paths`` so subsequent GETs can find each item. + batch_size = output.video.shape[0] if output.video.dim() == 5 else 1 + tensor_paths = [ + self.media_storage_path / f"{video_id}_{i}.{request.format}" + for i in range(batch_size) + ] + saved_paths = output.save(tensor_paths, format=request.format) + else: + resolved_fmt, _ = resolve_video_format(request.format) + batch_size = output.video.shape[0] if output.video.dim() == 5 else 1 + paths_in = [self.media_storage_path / f"{video_id}_{i}" for i in range(batch_size)] + saved_paths = output.save( + paths_in, + format=resolved_fmt, + frame_rate=output.frame_rate or request.frame_rate or params.frame_rate, + ) latency = time.perf_counter() - background_start # seconds metrics = output.metrics generation = metrics.generation if metrics is not None else 0.0 @@ -396,20 +524,36 @@ async def get_video_content(self, video_id: str, raw_request: Request) -> Respon status_code=HTTPStatus.BAD_REQUEST, ) - # Try to use stored output path, otherwise check for both .mp4 and .avi + # Use the stored output path when present, otherwise probe the + # well-known output suffixes for this video_id — try both the + # bare ``{vid}{ext}`` and the batch-indexed ``{vid}_0{ext}`` + # names, matching the convention ``delete_video`` uses. video_path = None if job.output_path and os.path.exists(job.output_path): video_path = Path(job.output_path) else: - # Fall back to checking common extensions - for ext in [".mp4", ".avi"]: - candidate = self.media_storage_path / f"{video_id}{ext}" - if os.path.exists(candidate): - video_path = candidate + for ext in _KNOWN_VIDEO_OUTPUT_SUFFIXES: + for name in (f"{video_id}{ext}", f"{video_id}_0{ext}"): + candidate = self.media_storage_path / name + if os.path.exists(candidate): + video_path = candidate + break + if video_path is not None: break if video_path and os.path.exists(video_path): - media_type = "video/mp4" if video_path.suffix == ".mp4" else "video/x-msvideo" + suffix = video_path.suffix.lstrip(".") + # When the original ``POST /v1/videos`` requested + # ``response_format="b64_json"``, return the bytes + # as a base64 envelope so the async transport + # matches what the sync route does for the same + # ``response_format``. + if job.response_format == "b64_json": + return _b64_json_video_response(video_id, suffix, video_path) + if is_tensor_format(suffix): + media_type = "application/octet-stream" + else: + media_type = _video_content_type(video_path.suffix) return FileResponse( video_path, media_type=media_type, @@ -473,7 +617,7 @@ async def delete_video(self, video_id: str, raw_request: Request) -> Response: else: # Fall back to checking common extensions for either the # single-file name or the batch-indexed name. - for ext in [".mp4", ".avi"]: + for ext in _KNOWN_VIDEO_OUTPUT_SUFFIXES: for name in (f"{video_id}{ext}", f"{video_id}_0{ext}"): candidate = self.media_storage_path / name if os.path.exists(candidate): diff --git a/tensorrt_llm/serve/visual_gen_utils.py b/tensorrt_llm/serve/visual_gen_utils.py index 0a454a726210..3094bf66cb04 100644 --- a/tensorrt_llm/serve/visual_gen_utils.py +++ b/tensorrt_llm/serve/visual_gen_utils.py @@ -4,59 +4,155 @@ import shutil from typing import Any, Dict, List, Optional -from tensorrt_llm.serve.openai_protocol import ( - ImageEditRequest, - ImageGenerationRequest, - VideoGenerationRequest, -) +from tensorrt_llm.logger import logger +from tensorrt_llm.serve.openai_protocol import ImageGenerationRequest, VideoGenerationRequest from tensorrt_llm.visual_gen import VisualGen, VisualGenParams +# Per-field warnings for OpenAI-shaped knobs that the engine has no +# semantic for. Each entry maps the request attribute to the message +# logged when the client sends a non-None value. +_NO_SEMANTIC_FIELD_WARNINGS: Dict[str, str] = { + "quality": ( + "Request field 'quality' accepted for OpenAI-SDK compatibility but " + "ignored; pass 'num_inference_steps' for explicit step control." + ), + "style": ( + "Request field 'style' accepted for OpenAI-SDK compatibility but " + "ignored; the engine has no equivalent semantic." + ), +} + + +def _warn_if_set_with_no_semantic( + request: ImageGenerationRequest | VideoGenerationRequest, + loaded_model_id: Optional[str] = None, +) -> None: + """Log WARNING for OpenAI-shape fields the engine cannot honor. + + ``model`` is warn-on-mismatch (trtllm-serve is single-model per + process). ``quality`` and ``style`` are warn-on-set. ``user`` is + accepted silently — it's an OpenAI trace field with no engine + semantic and keeps request logs clean. + """ + for field, message in _NO_SEMANTIC_FIELD_WARNINGS.items(): + if getattr(request, field, None) is not None: + logger.warning(message) + model_value = getattr(request, "model", None) + if model_value is not None and loaded_model_id is not None and model_value != loaded_model_id: + logger.warning( + "Request field 'model'=%r does not match the loaded model " + "%r; the model field is logged but ignored.", + model_value, + loaded_model_id, + ) + + +def _merge_extra_params( + params: VisualGenParams, + request_extras: Optional[Dict[str, Any]], + extra_param_specs: Dict[str, Any], +) -> None: + """Shallow-merge request ``extra_params`` into ``params.extra_params``. + + Pipeline defaults are already populated in ``params.extra_params`` + by ``generator.default_params``. Per-key behavior: + + - Known key + non-null value: override the default. + - Known key + ``null`` value: keep the pipeline default. The + pre-seeded default already encodes the right state; do not pop + so pipelines that genuinely distinguish ``None`` from "absent" + see the same value they would for a client that omitted the key. + - Unknown key + any value (including ``null``): pass through to + ``params.extra_params`` so the executor's strict-key validation + raises ``unknown_extra_param``. This is the key guarantee + against silent typos — schema-blind null stripping would let + ``{"stg_sclae": null}`` produce a 200 with retained defaults. + + When the request supplies no extras and the pipeline declared + none either, the params dict is normalized to ``None`` to match + the convention that "no extras" is the absence of the dict. + """ + if request_extras: + if params.extra_params is None: + params.extra_params = {} + for key, value in request_extras.items(): + if key in extra_param_specs and value is None: + continue + params.extra_params[key] = value + + if not params.extra_params: + params.extra_params = None + def parse_visual_gen_params( - request: ImageGenerationRequest | VideoGenerationRequest | ImageEditRequest, + request: ImageGenerationRequest | VideoGenerationRequest, id: str, generator: VisualGen, media_storage_path: Optional[str] = None, ) -> VisualGenParams: - # Start from the pipeline's resolved defaults so unspecified request - # fields keep the model's defaults instead of being overwritten with None. + """Translate an HTTP request into :class:`VisualGenParams`. + + Starts from ``generator.default_params`` (already populated with + pipeline-level defaults plus per-key ``extra_params`` defaults) and + overlays only the fields the client sent with a non-``None`` value. + The HTTP layer never invents a default. Validation lives elsewhere: + Pydantic at the request boundary (422), this helper for translation + errors (400 via ``ValueError``), and the executor's + ``validate_visual_gen_params`` for ``extra_params`` + strict-key/type/range checks (400 via ``ValueError``). + """ params = generator.default_params - if params.extra_params is None: - params.extra_params = {} + # Resolution: structured (width + height) wins over the OpenAI-shaped + # ``size`` string. Sending exactly one of {width, height} is rejected + # at the Pydantic boundary by the request's model_validator. + if request.width is not None and request.height is not None: + params.width, params.height = request.width, request.height + elif request.size is not None and request.size != "auto": + params.width, params.height = map(int, request.size.split("x")) + + # Universal per-request overlays — each guard is the "do not + # override with None" rule in action. if request.negative_prompt is not None: params.negative_prompt = request.negative_prompt - if request.size is not None and request.size != "auto": - params.width, params.height = map(int, request.size.split("x")) + if request.num_inference_steps is not None: + params.num_inference_steps = request.num_inference_steps if request.guidance_scale is not None: params.guidance_scale = request.guidance_scale - if request.guidance_rescale is not None: - params.extra_params["guidance_rescale"] = request.guidance_rescale - - if isinstance(request, (ImageGenerationRequest, ImageEditRequest)): - if request.num_inference_steps is not None: - params.num_inference_steps = request.num_inference_steps - elif isinstance(request, ImageGenerationRequest) and request.quality == "hd": - params.num_inference_steps = 30 + if request.max_sequence_length is not None: + params.max_sequence_length = request.max_sequence_length + if request.seed is not None: + params.seed = int(request.seed) + + if isinstance(request, ImageGenerationRequest): if request.n is not None: params.num_images_per_prompt = request.n - if isinstance(request, ImageEditRequest): - if request.image is not None: - if isinstance(request.image, list): - params.image = [base64.b64decode(image) for image in request.image] - else: - params.image = [base64.b64decode(request.image)] - if request.mask is not None: - if isinstance(request.mask, list): - params.mask = [base64.b64decode(mask) for mask in request.mask] - else: - params.mask = base64.b64decode(request.mask) elif isinstance(request, VideoGenerationRequest): - if request.num_inference_steps is not None: - params.num_inference_steps = request.num_inference_steps - if request.n is not None: - params.num_images_per_prompt = request.n + if request.frame_rate is not None: + params.frame_rate = request.frame_rate + # num_frames wins; otherwise derive from seconds * frame_rate + # (using whichever frame_rate is now in effect on params). + if request.num_frames is not None: + params.num_frames = request.num_frames + elif request.seconds is not None: + if params.frame_rate is None: + raise ValueError( + f"Cannot derive 'num_frames' from seconds={request.seconds}: " + "neither the request nor the loaded pipeline declares a " + "'frame_rate'. Pass 'fps' / 'frame_rate' alongside " + "'seconds', or pass 'num_frames' directly." + ) + derived = int(request.seconds * params.frame_rate) + if derived < 1: + raise ValueError( + f"Derived frame count is {derived} (from seconds=" + f"{request.seconds} * frame_rate={params.frame_rate}); " + "at least 1 frame is required. Pass a larger 'seconds' " + "value, a larger 'fps' / 'frame_rate', or 'num_frames' " + "directly." + ) + params.num_frames = derived if request.input_reference is not None: if media_storage_path is None: raise ValueError("media_storage_path is required when input_reference is provided") @@ -69,16 +165,8 @@ def parse_visual_gen_params( shutil.copyfileobj(request.input_reference.file, f) params.image = ref_path - params.frame_rate = request.fps - params.num_frames = int(request.seconds * request.fps) - - if request.seed is not None: - params.seed = int(request.seed) - - # Drop extra_params if we didn't end up with any — matches VisualGenParams - # convention where None means "no extras" for pipelines that declare none. - if not params.extra_params: - params.extra_params = None + _warn_if_set_with_no_semantic(request, getattr(generator, "model", None)) + _merge_extra_params(params, request.extra_params, generator.extra_param_specs) return params diff --git a/tensorrt_llm/visual_gen/output.py b/tensorrt_llm/visual_gen/output.py index dca2efec70ca..0660e661122f 100644 --- a/tensorrt_llm/visual_gen/output.py +++ b/tensorrt_llm/visual_gen/output.py @@ -18,6 +18,30 @@ from tensorrt_llm.llmapi.utils import set_api_status +def _infer_format_from_path( + path: Union[str, Path, List[Union[str, Path]]], +) -> Optional[str]: + """Return the tensor format implied by *path*'s suffix, or ``None``. + + For a list of paths, every entry must share the same recognized + tensor suffix; mixed or unrecognized suffixes return ``None`` and + let the image/video encoder dispatch handle them. + """ + from tensorrt_llm.media.tensor_payload import TENSOR_FORMATS + + def _suffix_format(p) -> Optional[str]: + suffix = Path(p).suffix + fmt = suffix[1:] if suffix.startswith(".") else suffix + return fmt if fmt in TENSOR_FORMATS else None + + if isinstance(path, list): + if not path: + return None + formats = {_suffix_format(p) for p in path} + return next(iter(formats)) if len(formats) == 1 else None + return _suffix_format(path) + + @set_api_status("prototype") @dataclass class VisualGenMetrics: @@ -86,18 +110,19 @@ def save( audio_sample_rate: Optional[int] = None, quality: int = 95, ) -> Union[Path, List[Path]]: - """Encode this output to disk via :mod:`tensorrt_llm.media.encoding`. + """Encode this output to disk. Args: path: Where to write. A single :class:`str`/:class:`pathlib.Path` writes one file (batched tensors collapse to the first - slice); a list of paths writes one file per batch item via - :func:`~tensorrt_llm.media.encoding.save_images` / - :func:`~tensorrt_llm.media.encoding.save_videos`. In both - cases format is inferred from the extension unless - ``format`` is given. - format: Explicit format override (``'png'``/``'jpg'``/``'webp'`` - for images, ``'mp4'``/``'avi'`` for video). + slice); a list of paths writes one file per batch item. + Format is inferred from the extension unless ``format`` + is given. + format: Explicit format. Image encoders: ``"png"``, ``"jpg"``, + ``"webp"``. Video encoders: ``"mp4"``, ``"avi"``. Tensor + payloads: ``"safetensors"``, ``"pt"`` — these carry every + populated modality (image/video/audio) plus scalar + metadata (frame_rate, audio_sample_rate) in one file. frame_rate: Override the frame rate for video output. Defaults to ``self.frame_rate`` when not provided. audio_sample_rate: Override the audio sample rate. Defaults to @@ -113,9 +138,11 @@ def save( ValueError: When video output lacks a frame rate, when the output carries no media tensor at all, or when the list length does not match the batch size. - NotImplementedError: When the output is audio-only. + NotImplementedError: When the output is audio-only and a + non-tensor format is requested. """ from tensorrt_llm.media.encoding import save_image, save_images, save_video, save_videos + from tensorrt_llm.media.tensor_payload import is_tensor_format if self.error is not None: raise RuntimeError( @@ -124,6 +151,21 @@ def save( is_batch = isinstance(path, list) + # Tensor formats carry every populated modality in one payload, + # so the dispatch table for image/video/audio below does not + # apply. When ``format`` is omitted, infer it from the path + # suffix so callers using the documented extension convention + # (``out.safetensors``/``out.pt``) reach the tensor path. + resolved_format = format if format is not None else _infer_format_from_path(path) + if is_tensor_format(resolved_format): + return self._save_tensor_payload( + path, + resolved_format, + is_batch=is_batch, + frame_rate=frame_rate, + audio_sample_rate=audio_sample_rate, + ) + if self.image is not None: if is_batch: saved_list = save_images( @@ -169,3 +211,99 @@ def save( f"Cannot save output: request {self.request_id} carries no media " "(image/video/audio are all None)." ) + + def _save_tensor_payload( + self, + path: Union[str, Path, List[Union[str, Path]]], + fmt: str, + *, + is_batch: bool, + frame_rate: Optional[float] = None, + audio_sample_rate: Optional[int] = None, + ) -> Union[Path, List[Path]]: + """Write the safetensors/pt payload for this output to *path*. + + A single path writes one logical output: when the populated + media tensor is batched the payload corresponds to the first + item, matching the image/video encoder paths + (:func:`~tensorrt_llm.media.encoding.save_image` / + :func:`~tensorrt_llm.media.encoding.save_video`). A list of + paths writes one payload per batch item by slicing the + populated tensors along their leading batch axis. + + ``frame_rate`` and ``audio_sample_rate`` override the + corresponding fields on ``self`` when present, matching the + encoder path's override semantics. + """ + from tensorrt_llm.media.tensor_payload import ( + infer_batch_size, + save_visual_gen_output_payload, + ) + + batch_size = infer_batch_size(self) + + if not is_batch: + if batch_size > 1: + raise ValueError( + f"save received a single path but the output carries a batched " + f"tensor of size {batch_size}; pass a list of {batch_size} paths " + "(one per item)." + ) + slice_index = 0 if batch_size > 0 else None + return save_visual_gen_output_payload( + self, + path, + fmt, + batch_index=slice_index, + frame_rate=frame_rate, + audio_sample_rate=audio_sample_rate, + ) + + if len(path) != batch_size: + raise ValueError( + f"Number of paths ({len(path)}) does not match batch size ({batch_size})." + ) + return [ + save_visual_gen_output_payload( + self, + p, + fmt, + batch_index=i, + frame_rate=frame_rate, + audio_sample_rate=audio_sample_rate, + ) + for i, p in enumerate(path) + ] + + def _save_bytes( + self, + format: str, + *, + batch_index: Optional[int] = None, + frame_rate: Optional[float] = None, + audio_sample_rate: Optional[int] = None, + ) -> bytes: + """Serialize this output to bytes for in-memory transport. + + Internal counterpart to :meth:`save`. The public output API + exposes only :meth:`save` (file-based); the in-memory bytes + path is reserved for trtllm-serve's ``b64_json`` transport, + which derives ``batch_index`` from + :func:`tensorrt_llm.media.tensor_payload.infer_batch_size` + before iterating. Only tensor formats are supported today. + """ + from tensorrt_llm.media.tensor_payload import is_tensor_format, serialize_visual_gen_output + + if self.error is not None: + raise RuntimeError( + f"Cannot save output: request {self.request_id} failed with error: {self.error}" + ) + if not is_tensor_format(format): + raise ValueError(f"_save_bytes supports only tensor formats today; got {format!r}.") + return serialize_visual_gen_output( + self, + format, + batch_index=batch_index, + frame_rate=frame_rate, + audio_sample_rate=audio_sample_rate, + ) diff --git a/tensorrt_llm/visual_gen/params.py b/tensorrt_llm/visual_gen/params.py index 87754a56c956..9e71a25d0505 100644 --- a/tensorrt_llm/visual_gen/params.py +++ b/tensorrt_llm/visual_gen/params.py @@ -46,7 +46,20 @@ class VisualGenParams(StrictBaseModel): max_sequence_length: Optional[int] = Field( default=None, description="Max tokens for text encoding." ) - seed: int = Field(default=42, description="Random seed for reproducibility.") + # When ``num_images_per_prompt > 1`` is honored end-to-end (future), + # the implementation follows the diffusers/vllm-omni convention: + # one ``torch.Generator(seed=s)`` drives ``N`` latents from a single + # RNG stream (batched ``randn``), not SGLang's per-image + # ``[s, s+1, …]`` expansion. Adding ``seed: int | list[int]`` is + # left as an additive extension if explicit per-image seeds become + # a requirement. + seed: Optional[int] = Field( + default=None, + description=( + "Random seed for reproducibility. ``None`` means the engine draws " + "a fresh seed on the coordinator rank before pipeline dispatch." + ), + ) # Video num_frames: Optional[int] = Field( @@ -59,12 +72,6 @@ class VisualGenParams(StrictBaseModel): image: Optional[Union[str, bytes, List[Union[str, bytes]]]] = Field( default=None, description="Reference image(s) for I2V/I2I." ) - mask: Optional[Union[str, bytes, List[bytes]]] = Field( - default=None, description="Inpainting mask path or raw bytes." - ) - image_cond_strength: Optional[float] = Field( - default=None, description="Image conditioning strength." - ) # Per-prompt multiplier num_images_per_prompt: int = Field(default=1, description="Number of images per prompt.") @@ -75,3 +82,107 @@ class VisualGenParams(StrictBaseModel): description="Model-specific parameters. Use VisualGen.extra_param_specs " "to discover valid keys for the loaded pipeline.", ) + + +# Python type name → accepted Python types for ``ExtraParamSchema`` validation. +# The validator duck-types ``ExtraParamSchema`` via ``spec.type`` / ``spec.range`` +# so it does not need to import the (internal) schema class. +_TYPE_MAP = { + "float": (float, int), + "int": (int,), + "bool": (bool,), + "str": (str,), + "list": (list,), +} + +# Generation config fields that pipelines declare defaults for. If a user +# sets one of these but the pipeline doesn't declare it in +# ``default_generation_params``, the request is rejected so unsupported +# knobs don't get silently dropped. Conditioning inputs ``image`` and +# ``negative_prompt`` are validated at runtime by the pipeline's +# ``infer()`` and stay out of this set. +_GENERATION_CONFIG_FIELDS: tuple = ( + "height", + "width", + "num_inference_steps", + "guidance_scale", + "max_sequence_length", + "num_frames", + "frame_rate", +) + + +def validate_visual_gen_params( + params: VisualGenParams, + *, + declared_defaults: Optional[Dict[str, Any]], + extra_param_specs: Dict[str, Any], +) -> None: + """Validate *params* against pipeline-declared defaults and extra specs. + + Called on the coordinator side at :meth:`VisualGen.generate_async` + entry (and again as a pre-flight check by the async video route, so + a malformed request becomes HTTP 400 before the job is queued). + Raises :class:`ValueError` with a multi-line message listing every + violation when one or more of: + + - Unknown ``extra_params`` keys. + - Universal fields (e.g. ``num_frames``) set by the user but not + declared in ``declared_defaults``. Skipped when ``declared_defaults`` + is ``None`` — clients that don't carry the per-pipeline universal + field set can still validate ``extra_params``. + - Type mismatches for ``extra_params`` values. + - Out-of-range ``extra_params`` values. + """ + messages: List[str] = [] + specs = extra_param_specs + + # --- unknown extra_params keys --- + if params.extra_params: + unknown = sorted(set(params.extra_params.keys()) - set(specs.keys())) + if unknown: + messages.append(f"Unknown extra_params {unknown}. Supported: {sorted(specs.keys())}") + + # --- unsupported universal fields --- + # Check generation config fields the user explicitly set (not None) + # that the loaded pipeline never declared in declared_defaults. + # Conditioning inputs (image, negative_prompt) are excluded — they + # are validated at runtime by the pipeline's infer(). + if declared_defaults is not None: + for field_name in _GENERATION_CONFIG_FIELDS: + value = getattr(params, field_name, None) + if value is not None and field_name not in declared_defaults: + messages.append( + f"Parameter '{field_name}' is set but the loaded " + f"pipeline does not accept it (not in default_generation_params)." + ) + + # --- extra_params type and range checks --- + if params.extra_params: + for key, value in params.extra_params.items(): + if key not in specs: + continue # already reported as unknown above + spec = specs[key] + # Skip None values (param left at its None default) + if value is None: + continue + # Type check + expected_types = _TYPE_MAP.get(spec.type) + if expected_types and not isinstance(value, expected_types): + messages.append( + f"extra_params['{key}'] expected type '{spec.type}', " + f"got {type(value).__name__}: {value!r}" + ) + continue # skip range check if type is wrong + # Range check (numeric only) + if spec.range is not None and isinstance(value, (int, float)): + lo, hi = spec.range + if not (lo <= value <= hi): + messages.append( + f"extra_params['{key}'] value {value} is out of range [{lo}, {hi}]" + ) + + if not messages: + return + + raise ValueError("Parameter validation failed:\n" + "\n".join(f" - {e}" for e in messages)) diff --git a/tensorrt_llm/visual_gen/visual_gen.py b/tensorrt_llm/visual_gen/visual_gen.py index 3073f708adf7..c59a140418cf 100644 --- a/tensorrt_llm/visual_gen/visual_gen.py +++ b/tensorrt_llm/visual_gen/visual_gen.py @@ -15,28 +15,24 @@ import asyncio import atexit import itertools -import os -import queue -import socket +import secrets import sys -import threading -import time -import traceback import weakref from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union - -import torch.multiprocessing as mp -import zmq +from typing import Any, Dict, List, Literal, Optional, Union from tensorrt_llm._torch.visual_gen import DiffusionRequest, DiffusionResponse -from tensorrt_llm._torch.visual_gen.executor import run_diffusion_worker +from tensorrt_llm._torch.visual_gen.executor import ( + DiffusionRemoteClient, + _detect_external_launch, + run_diffusion_worker, +) from tensorrt_llm._torch.visual_gen.output import split_visual_gen_output, to_visual_gen_output from tensorrt_llm._torch.visual_gen.pipeline import ExtraParamSchema from tensorrt_llm._torch.visual_gen.pipeline_registry import PIPELINE_REGISTRY, AutoPipeline from tensorrt_llm.visual_gen.args import VisualGenArgs from tensorrt_llm.visual_gen.output import VisualGenOutput -from tensorrt_llm.visual_gen.params import VisualGenParams +from tensorrt_llm.visual_gen.params import VisualGenParams, validate_visual_gen_params __all__ = [ "VisualGen", @@ -44,508 +40,9 @@ "ExtraParamSchema", "VisualGenResult", ] -from tensorrt_llm.executor.ipc import ZeroMqQueue from tensorrt_llm.llmapi.utils import set_api_status from tensorrt_llm.logger import logger -# Timeouts (seconds) -POLL_TIMEOUT = 0.01 -AWAIT_TIMEOUT = 0.05 -THREAD_TIMEOUT = 5.0 -WORKER_TIMEOUT = 2.0 - - -def find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - - -def get_ip_address() -> str: - """Get local IP address.""" - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - s.connect(("10.255.255.255", 1)) - return s.getsockname()[0] - except Exception: - return "127.0.0.1" - finally: - s.close() - - -def _detect_external_launch() -> Optional[Tuple[int, int, int, str, int]]: - """Detect whether the process was launched by an external distributed launcher. - - Checks for torchrun (``RANK`` + ``WORLD_SIZE``) and then SLURM - (``SLURM_PROCID`` + ``SLURM_NTASKS``). Returns a - ``(rank, local_rank, world_size, master_addr, master_port)`` tuple when a - multi-process launcher is detected (world_size > 1), or ``None`` for - single-process / single-node ``mp.Process`` mode. - """ - # torchrun / torchelastic sets RANK and WORLD_SIZE - if "RANK" in os.environ and "WORLD_SIZE" in os.environ: - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - if world_size > 1: - local_rank = int(os.environ.get("LOCAL_RANK", rank)) - master_addr = os.environ.get("MASTER_ADDR") - if master_addr is None: - raise RuntimeError( - "MASTER_ADDR must be set for multi-node torchrun runs. " - "Add --master-addr= to your torchrun command, or set " - "MASTER_ADDR in the environment before launching." - ) - master_port = int(os.environ.get("MASTER_PORT", 29500)) - return rank, local_rank, world_size, master_addr, master_port - - # SLURM: srun --ntasks-per-node=GPUS_PER_NODE sets SLURM_PROCID / SLURM_NTASKS - if "SLURM_PROCID" in os.environ and "SLURM_NTASKS" in os.environ: - rank = int(os.environ["SLURM_PROCID"]) - world_size = int(os.environ["SLURM_NTASKS"]) - if world_size > 1: - local_rank = int(os.environ.get("SLURM_LOCALID", rank)) - master_addr = os.environ.get("MASTER_ADDR") - if master_addr is None: - raise RuntimeError( - "MASTER_ADDR must be set for multi-node SLURM runs. " - "Add to your sbatch script:\n" - " MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -1)" - ) - master_port = int(os.environ.get("MASTER_PORT", 29500)) - return rank, local_rank, world_size, master_addr, master_port - - return None - - -class DiffusionRemoteClient: - """Client proxy for remote DiffusionExecutor in worker processes. - - Supports two launch modes: - - **Single-node (default)** - ``VisualGen`` is called from an ordinary Python script. - ``DiffusionRemoteClient`` spawns all worker processes locally via - ``mp.Process`` with ``master_addr=127.0.0.1``. - - **Multi-node (external launcher)** - The script is launched by ``torchrun`` or ``srun --ntasks-per-node=GPUS``. - Each rank runs the same script; ``RANK`` / ``WORLD_SIZE`` / ``MASTER_ADDR`` - / ``MASTER_PORT`` are already set in the environment. - - - Rank 0: becomes the request coordinator. It creates the ZMQ server - sockets and starts its own worker in a background thread, then returns - to the caller so the user script can call ``generate()``. - - Rank > 0: handled by ``VisualGen.__init__`` before this class is - instantiated — they call ``run_diffusion_worker`` directly and exit - via ``sys.exit(0)``. These ranks never reach ``DiffusionRemoteClient``. - """ - - def __init__( - self, - args: VisualGenArgs, - ): - self.args = args - self.n_workers = args.parallel_config.n_workers - - # --- Detect external launcher (torchrun / srun) --- - ext = _detect_external_launch() - - if ext is None: - # Single-node: coordinator spawns all workers locally - # Setup distributed env - self.master_addr = "127.0.0.1" - self.master_port = find_free_port() - - # Setup IPC addresses - self.host_ip = get_ip_address() - req_port, resp_port = find_free_port(), find_free_port() - - self.request_queue_addr = f"tcp://0.0.0.0:{req_port}" - self.response_queue_addr = f"tcp://0.0.0.0:{resp_port}" - self.req_addr_connect = f"tcp://{self.host_ip}:{req_port}" - self.resp_addr_connect = f"tcp://{self.host_ip}:{resp_port}" - - else: - # rank == 0 guaranteed — ranks 1..N-1 exited in VisualGen.__init__ - rank, local_rank, world_size, master_addr, master_port = ext - req_port = find_free_port() - resp_port = find_free_port() - self.master_addr = master_addr - self.master_port = master_port - self.request_queue_addr = f"tcp://0.0.0.0:{req_port}" - self.response_queue_addr = f"tcp://0.0.0.0:{resp_port}" - self.req_addr_connect = f"tcp://{master_addr}:{req_port}" - self.resp_addr_connect = f"tcp://{master_addr}:{resp_port}" - - # Generate shared HMAC keys for IPC authentication - self.req_hmac_key = os.urandom(32) - self.resp_hmac_key = os.urandom(32) - - # IPC setup - self.requests_ipc = None - self.responses_ipc = None - self.pending_requests = queue.Queue() - self.completed_responses: Dict[int, DiffusionResponse] = {} - # Request ids the caller has given up on (e.g., aresult timed out). - # _store_response drops late-arriving responses for these ids so a - # full PipelineOutput tensor does not pin in completed_responses for - # the process lifetime. - self._abandoned_request_ids: Set[int] = set() - - # We'll create asyncio primitives in the background thread's event loop - self._event_loop = None - self.response_event = None - self.lock = None - self.shutdown_event = threading.Event() - self.event_loop_ready = threading.Event() - - # Start background thread (it will create its own event loop) - self.background_thread = threading.Thread(target=self._serve_forever_thread, daemon=True) - self.background_thread.start() - - # Wait for the background thread to initialize the event loop - self.event_loop_ready.wait() - - # Pipeline metadata — populated by _wait_ready from the READY signal. - self.default_generation_params: Dict = {} - self.extra_param_specs: Dict = {} - - # --- Launch workers --- - self.worker_processes = [] - self._ext_worker_thread: Optional[threading.Thread] = None - - if ext is None: - logger.info(f"DiffusionClient: Launching {self.n_workers} workers") - ctx = mp.get_context("spawn") - for rank in range(self.n_workers): - p = ctx.Process( - target=run_diffusion_worker, - kwargs={ - "rank": rank, - "world_size": self.n_workers, - "master_addr": self.master_addr, - "master_port": self.master_port, - "request_queue_addr": self.req_addr_connect, - "response_queue_addr": self.resp_addr_connect, - "visual_gen_args": self.args, - "req_hmac_key": self.req_hmac_key, - "resp_hmac_key": self.resp_hmac_key, - "log_level": logger.level, - "local_rank": rank, - }, - ) - p.start() - self.worker_processes.append(p) - else: - # External launch: rank 0 runs its own worker in a background thread. - # Other nodes' workers are already running (they were launched by the - # external launcher and will connect to our ZMQ server once it binds). - self._ext_worker_thread = threading.Thread( - target=run_diffusion_worker, - kwargs={ - "rank": rank, - "world_size": self.n_workers, - "master_addr": master_addr, - "master_port": master_port, - "request_queue_addr": self.req_addr_connect, - "response_queue_addr": self.resp_addr_connect, - "visual_gen_args": self.args, - "req_hmac_key": self.req_hmac_key, - "resp_hmac_key": self.resp_hmac_key, - "log_level": logger.level, - "local_rank": local_rank, - }, - daemon=True, - ) - self._ext_worker_thread.start() - - self._wait_ready() - - @staticmethod - def _close_socket(ipc_queue): - if ipc_queue and ipc_queue.socket: - ipc_queue.socket.setsockopt(zmq.LINGER, 0) - ipc_queue.close() - - def enqueue_requests(self, requests: List[DiffusionRequest]) -> List[int]: - """Enqueue requests and return their IDs.""" - req_ids = [] - for req in requests: - self.pending_requests.put(req) - req_ids.append(req.request_id) - return req_ids - - async def await_responses( - self, request_ids: Union[int, List[int]], timeout: Optional[float] = None - ) -> Union[DiffusionResponse, List[DiffusionResponse]]: - """Wait for responses by request IDs. - - Args: - request_ids: Single request ID or list of request IDs to wait for - timeout: Maximum total wait time in seconds (None = wait indefinitely) - - Returns: - Single response or list of responses (None if request timed out) - """ - is_single = isinstance(request_ids, int) - ids = [request_ids] if is_single else request_ids - - start_time = time.time() - results = {} - - while len(results) < len(ids): - async with self.lock: - for req_id in ids: - if req_id in self.completed_responses: - results[req_id] = self.completed_responses.pop(req_id) - - # All responses collected - if len(results) == len(ids): - break - - # Check if overall timeout exceeded - if timeout is not None: - elapsed = time.time() - start_time - if elapsed >= timeout: - break - # Wait for remaining time or AWAIT_TIMEOUT, whichever is shorter - wait_time = min(timeout - elapsed, AWAIT_TIMEOUT) - else: - wait_time = AWAIT_TIMEOUT - - try: - await asyncio.wait_for(self.response_event.wait(), timeout=wait_time) - except asyncio.TimeoutError: - pass - self.response_event.clear() - - out = [results.get(rid) for rid in ids] - return out[0] if is_single else out - - def await_responses_sync( - self, request_ids: Union[int, List[int]], timeout: Optional[float] = None - ) -> Union[DiffusionResponse, List[DiffusionResponse]]: - """Sync wrapper to await responses from the main thread.""" - future = asyncio.run_coroutine_threadsafe( - self.await_responses(request_ids, timeout), self._event_loop - ) - return future.result(timeout=timeout if timeout else None) - - def _init_ipc(self) -> bool: - """Initialize IPC queues.""" - try: - logger.info("DiffusionClient: Initializing IPC") - self.requests_ipc = ZeroMqQueue( - (self.request_queue_addr, self.req_hmac_key), - is_server=True, - socket_type=zmq.PUSH, - use_hmac_encryption=True, - ) - self.responses_ipc = ZeroMqQueue( - (self.response_queue_addr, self.resp_hmac_key), - is_server=True, - socket_type=zmq.PULL, - use_hmac_encryption=True, - ) - logger.info("DiffusionClient: IPC ready") - return True - except Exception as e: - logger.error(f"DiffusionClient: IPC init failed: {e}") - return False - - def _send_shutdown(self): - """Send shutdown signal.""" - logger.info("DiffusionClient: Sending shutdown signal") - if self.requests_ipc: - self.requests_ipc.put(None) - self._close_socket(self.requests_ipc) - - def _process_requests(self): - """Process pending requests.""" - try: - req = self.pending_requests.get(timeout=POLL_TIMEOUT) - if req is None: - self._send_shutdown() - self.shutdown_event.set() - return - - logger.info(f"DiffusionClient: Sending request {req.request_id}") - self.requests_ipc.put(req) - except queue.Empty: - pass - except Exception as e: - logger.error(f"DiffusionClient: Error sending request: {e}") - logger.error(traceback.format_exc()) - - def _process_responses(self): - """Poll and process responses.""" - try: - if self.responses_ipc.poll(timeout=POLL_TIMEOUT): - response = self.responses_ipc.get() - if isinstance(response, DiffusionResponse): - if response.request_id == -1: - logger.info("DiffusionClient: Received READY signal") - - # Schedule the lock acquisition and event setting in the event loop - asyncio.run_coroutine_threadsafe( - self._store_response(response), self._event_loop - ) - except Exception as e: - logger.error(f"DiffusionClient: Error processing response: {e}") - - async def _store_response(self, response: DiffusionResponse): - """Store response in the completed_responses dict (async helper). - - Drops the response if the request id has been abandoned so that - late-arriving responses for timed-out requests do not leak into - ``completed_responses`` for the process lifetime. - """ - async with self.lock: - if response.request_id in self._abandoned_request_ids: - self._abandoned_request_ids.discard(response.request_id) - return - self.completed_responses[response.request_id] = response - self.response_event.set() - - async def abandon_request_id(self, request_id: int): - """Mark a request id as abandoned and drop any cached response. - - Called from the result handle's timeout branch to prevent the - executor from holding a full ``PipelineOutput`` for a request whose - caller has stopped waiting. Handles both orderings: - - - Response already arrived between the timeout firing and the - abandon call → ``pop`` releases it here. - - Response arrives after the abandon call → ``_store_response`` - checks the abandoned set and drops it on arrival. - """ - async with self.lock: - self.completed_responses.pop(request_id, None) - self._abandoned_request_ids.add(request_id) - - def _cleanup_ipc(self): - """Cleanup IPC.""" - logger.info("DiffusionClient: Cleaning up IPC") - self._close_socket(self.requests_ipc) - self._close_socket(self.responses_ipc) - - def _serve_forever_thread(self): - """Background thread wrapper that creates and runs an event loop.""" - logger.info("DiffusionClient: Background thread started") - - # Create a new event loop for this thread - self._event_loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._event_loop) - - # Create async primitives in this thread's event loop - self.response_event = asyncio.Event() - self.lock = asyncio.Lock() - - # Signal that the event loop is ready - self.event_loop_ready.set() - - # Run the async serve_forever - try: - self._event_loop.run_until_complete(self._serve_forever()) - finally: - self._event_loop.close() - logger.info("DiffusionClient: Background thread stopped") - - async def _serve_forever(self): - """Background thread main loop (async version).""" - if not self._init_ipc(): - return - - while not self.shutdown_event.is_set(): - self._process_requests() - self._process_responses() - await asyncio.sleep(0.001) # Yield control to allow other coroutines to run - - self._cleanup_ipc() - - def shutdown(self): - """Shutdown client and workers.""" - logger.info("DiffusionClient: Shutting down") - self.pending_requests.put(None) - - self.background_thread.join(timeout=THREAD_TIMEOUT) - if self.background_thread.is_alive(): - logger.warning("DiffusionClient: Force stopping background thread") - self.shutdown_event.set() - self.background_thread.join(timeout=1.0) - - # Shutdown workers - logger.info("DiffusionClient: Stopping workers") - for p in self.worker_processes: - p.join(timeout=WORKER_TIMEOUT) - if p.is_alive(): - logger.warning(f"DiffusionClient: Terminating worker {p.pid} with SIGTERM") - p.terminate() - p.join(timeout=WORKER_TIMEOUT) - if p.is_alive(): - logger.warning(f"DiffusionClient: Force killing worker {p.pid} with SIGKILL") - p.kill() - p.join(timeout=WORKER_TIMEOUT) - - # External-launch mode: join rank-0 worker thread - if self._ext_worker_thread is not None and self._ext_worker_thread.is_alive(): - self._ext_worker_thread.join(timeout=WORKER_TIMEOUT) - - def _wait_ready(self): - """Wait for workers to be ready (sync wrapper for async operation).""" - logger.info("DiffusionClient: Waiting for workers") - - future = asyncio.run_coroutine_threadsafe(self._wait_ready_async(), self._event_loop) - try: - future.result() - except Exception: - self.shutdown() - raise - - async def _wait_ready_async(self): - """Wait for workers to be ready (async version). - - Polls indefinitely for the ready signal. If any worker process dies - during initialization, raises RuntimeError immediately (LLM-style). - """ - start_time = time.time() - last_log_time = start_time - log_interval = 300 - - while True: - async with self.lock: - if -1 in self.completed_responses: - ready_resp = self.completed_responses.pop(-1) - # Extract pipeline metadata from the READY payload. - payload = ready_resp.output - if isinstance(payload, dict): - self.default_generation_params = payload.get( - "default_generation_params", {} - ) - self.extra_param_specs = payload.get("extra_param_specs", {}) - elapsed = time.time() - start_time - logger.info(f"DiffusionClient: Workers ready ({elapsed:.1f}s)") - return - - worker_dead = any(not p.is_alive() for p in self.worker_processes) - ext_dead = ( - self._ext_worker_thread is not None and not self._ext_worker_thread.is_alive() - ) - if worker_dead or ext_dead: - raise RuntimeError("DiffusionClient: Worker died during initialization") - - now = time.time() - if now - last_log_time >= log_interval: - elapsed = now - start_time - logger.info(f"DiffusionClient: Still waiting for workers ({elapsed:.0f}s elapsed)") - last_log_time = now - - try: - await asyncio.wait_for(self.response_event.wait(), timeout=AWAIT_TIMEOUT) - except asyncio.TimeoutError: - pass - self.response_event.clear() - @set_api_status("prototype") class VisualGenResult: @@ -664,8 +161,12 @@ def _build_resolved(self, response: "DiffusionResponse"): return split_visual_gen_output(response, self._batch_size) def _resolved_value(self): - # For single prompts, surface failure via RuntimeError. For batch, - # return the list as-is so callers can inspect per-item ``error``. + # For single prompts, surface engine-side failure as + # ``RuntimeError``. Request-parameter validation is enforced + # synchronously at :meth:`VisualGen.generate_async` entry, so + # anything reaching this point is by definition a runtime + # failure from ``pipeline.infer()``. For batches, return the + # list as-is so callers iterate per-item ``error``. if self._batch_size is None and isinstance(self._resolved, VisualGenOutput): if self._resolved.error is not None: raise RuntimeError(f"Generation failed: {self._resolved.error}") @@ -880,10 +381,36 @@ def generate_async( # Snapshot caller-provided params so later mutations don't affect # the queued request (the dispatcher thread serializes it lazily). + # When the caller passed no params, materialize a default + # :class:`VisualGenParams` from the loaded pipeline's + # declared defaults + extra-param specs (cached on the executor + # from the READY signal) and skip validation — there's nothing + # user-supplied to validate against. + if params is not None: + resolved_params = params.model_copy(deep=True) + # Raising in the caller's process means ``ValueError`` reaches + # the user as a natural Python exception; the worker only has + # to deal with genuine runtime failures from ``pipeline.infer()``. + validate_visual_gen_params( + resolved_params, + declared_defaults=self.executor.default_generation_params, + extra_param_specs=self.executor.extra_param_specs, + ) + else: + resolved_params = self.default_params + + # Materialize the seed once, here at the public Python boundary, + # so every downstream layer (executor, broadcast, pipeline) sees + # a concrete int. Drawing on the coordinator process and + # broadcasting the resolved value keeps multi-rank parallelism + # (cfg_size, ulysses_size) deterministic. + if resolved_params.seed is None: + resolved_params.seed = secrets.randbits(63) + request = DiffusionRequest( request_id=req_id, prompt=prompt, - params=params.model_copy(deep=True) if params is not None else None, + params=resolved_params, ) self.executor.enqueue_requests([request]) diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index e78984886eee..440fe2885fab 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -114,6 +114,7 @@ l0_a10: - unittest/_torch/visual_gen/test_visual_gen_params.py - unittest/visual_gen/test_output.py - unittest/media/test_encoding.py + - unittest/_torch/visual_gen/test_tensor_payload.py # llmapi - unittest/llmapi/test_llm_utils.py - unittest/llmapi/test_gc_utils.py diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 5fd6b5d6cd83..e9571ad9724a 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -189,6 +189,7 @@ l0_b200: # ------------- Visual Gen tests --------------- - unittest/_torch/visual_gen/test_visual_gen_args.py - unittest/_torch/visual_gen/test_visual_gen_params.py + - unittest/_torch/visual_gen/test_visual_gen_utils.py - unittest/_torch/visual_gen/test_warmup.py - unittest/_torch/visual_gen/test_teacache.py - unittest/_torch/visual_gen/test_cache_dit.py diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_multinode.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_multinode.py index 6f607fa87fc3..fddadb291d60 100644 --- a/tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_multinode.py +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_multinode.py @@ -223,10 +223,14 @@ def pre_set_event(): return e with ( - patch("tensorrt_llm.visual_gen.visual_gen._detect_external_launch", return_value=None), - patch("tensorrt_llm.visual_gen.visual_gen.mp.get_context", return_value=mock_ctx), - patch("tensorrt_llm.visual_gen.visual_gen.threading.Thread") as mock_thread_cls, - patch("tensorrt_llm.visual_gen.visual_gen.threading.Event", side_effect=pre_set_event), + patch( + "tensorrt_llm._torch.visual_gen.executor._detect_external_launch", return_value=None + ), + patch("tensorrt_llm._torch.visual_gen.executor.mp.get_context", return_value=mock_ctx), + patch("tensorrt_llm._torch.visual_gen.executor.threading.Thread") as mock_thread_cls, + patch( + "tensorrt_llm._torch.visual_gen.executor.threading.Event", side_effect=pre_set_event + ), patch.object(DiffusionRemoteClient, "_wait_ready"), ): mock_thread_cls.return_value = MagicMock() # thread.start() is a no-op diff --git a/tests/unittest/_torch/visual_gen/test_tensor_payload.py b/tests/unittest/_torch/visual_gen/test_tensor_payload.py new file mode 100644 index 000000000000..4ad482b886fd --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_tensor_payload.py @@ -0,0 +1,331 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for :mod:`tensorrt_llm.media.tensor_payload`. + +Covers the rank-aware batch semantics on +:meth:`VisualGenOutput.save` and :meth:`VisualGenOutput._save_bytes` +across both supported tokens (``"safetensors"`` and ``"pt"``). +""" + +from __future__ import annotations + +import io + +import pytest +import torch + +from tensorrt_llm.media.tensor_payload import ( + TENSOR_FORMATS, + infer_batch_size, + is_tensor_format, + serialize_visual_gen_output, +) +from tensorrt_llm.visual_gen.output import VisualGenOutput + + +def _safetensors_load(data: bytes) -> dict: + from safetensors.torch import load as load_safetensors + + return load_safetensors(data) + + +def _pt_load(data: bytes) -> dict: + return torch.load(io.BytesIO(data), weights_only=True) + + +def _make_image_output(batch: int = 1, h: int = 8, w: int = 8) -> VisualGenOutput: + """Image tensor uses canonical ``(B, H, W, C)`` shape.""" + img = torch.arange(batch * h * w * 3, dtype=torch.uint8).reshape(batch, h, w, 3) + return VisualGenOutput(request_id=1, image=img) + + +def _make_video_output(batch: int = 1, t: int = 2, h: int = 4, w: int = 4) -> VisualGenOutput: + """Video uses ``(B, T, H, W, C)``; LTX-2-style audio uses ``(B, channels, T_audio)``.""" + vid = torch.arange(batch * t * h * w * 3, dtype=torch.uint8).reshape(batch, t, h, w, 3) + audio = torch.arange(batch * 2 * 32, dtype=torch.float32).reshape(batch, 2, 32) / 100.0 + return VisualGenOutput( + request_id=2, + video=vid, + audio=audio, + frame_rate=24.0, + audio_sample_rate=16000, + ) + + +class TestIsTensorFormat: + def test_accepts_supported_tokens(self): + for token in TENSOR_FORMATS: + assert is_tensor_format(token) + + def test_rejects_encoder_tokens(self): + for token in ("png", "webp", "jpeg", "mp4", "avi", "auto", None, "npz"): + assert not is_tensor_format(token) + + +class TestInferBatchSize: + def test_image_rank_4_is_batched(self): + output = _make_image_output(batch=3) + assert infer_batch_size(output) == 3 + + def test_image_rank_3_is_unbatched(self): + output = VisualGenOutput(request_id=1, image=torch.zeros(8, 8, 3, dtype=torch.uint8)) + assert infer_batch_size(output) == 1 + + def test_video_rank_5_is_batched(self): + output = _make_video_output(batch=2) + assert infer_batch_size(output) == 2 + + def test_video_rank_4_is_unbatched(self): + """An unbatched video has shape ``(T, H, W, C)``; the frame axis + must not be confused with a batch dimension.""" + video = torch.zeros(8, 4, 4, 3, dtype=torch.uint8) + output = VisualGenOutput(request_id=1, video=video, frame_rate=12.0) + assert infer_batch_size(output) == 1 + + def test_inconsistent_batches_raise(self): + # Image rank-4 batch=2 vs video rank-5 batch=3 — must error out. + output = VisualGenOutput( + request_id=1, + image=torch.zeros(2, 4, 4, 3, dtype=torch.uint8), + video=torch.zeros(3, 2, 4, 4, 3, dtype=torch.uint8), + frame_rate=24.0, + ) + with pytest.raises(ValueError, match="Inconsistent batch sizes"): + infer_batch_size(output) + + def test_no_media_raises(self): + with pytest.raises(ValueError, match="carries no media"): + infer_batch_size(VisualGenOutput(request_id=1)) + + +@pytest.mark.parametrize("fmt", ["safetensors", "pt"]) +class TestSingleSavePath: + """A single path writes one logical output. Unbatched tensors and + ``batch == 1`` tensors save as-is; ``batch > 1`` raises ``ValueError`` + to force the caller to pass a list of paths.""" + + def test_unbatched_image_writes_full_tensor(self, fmt, tmp_path): + img = torch.arange(8 * 8 * 3, dtype=torch.uint8).reshape(8, 8, 3) + output = VisualGenOutput(request_id=1, image=img) + target = tmp_path / "img" + saved = output.save(target, format=fmt) + loaded = (_safetensors_load if fmt == "safetensors" else _pt_load)(saved.read_bytes()) + assert loaded["image"].shape == (8, 8, 3) + assert torch.equal(loaded["image"], img) + + def test_batched_image_single_path_raises(self, fmt, tmp_path): + output = _make_image_output(batch=3) + target = tmp_path / "img" + with pytest.raises(ValueError, match="batched tensor of size 3"): + output.save(target, format=fmt) + + def test_batched_video_single_path_raises(self, fmt, tmp_path): + output = _make_video_output(batch=2) + target = tmp_path / "vid" + with pytest.raises(ValueError, match="batched tensor of size 2"): + output.save(target, format=fmt) + + def test_path_suffix_is_normalized(self, fmt, tmp_path): + output = _make_image_output(batch=1) + saved = output.save(tmp_path / "no_ext", format=fmt) + assert saved.suffix == f".{fmt}" + + +@pytest.mark.parametrize("fmt", ["safetensors", "pt"]) +class TestListSavePath: + """A list of paths writes one payload per batch item.""" + + def test_batched_image_fans_out(self, fmt, tmp_path): + output = _make_image_output(batch=3) + paths = [tmp_path / f"img_{i}" for i in range(3)] + saved = output.save(paths, format=fmt) + assert len(saved) == 3 + for i, p in enumerate(saved): + loaded = (_safetensors_load if fmt == "safetensors" else _pt_load)(p.read_bytes()) + assert loaded["image"].shape == (8, 8, 3) + assert torch.equal(loaded["image"], output.image[i]) + + def test_path_count_mismatch_raises(self, fmt, tmp_path): + output = _make_image_output(batch=3) + with pytest.raises(ValueError, match="does not match batch size"): + output.save([tmp_path / "img_0", tmp_path / "img_1"], format=fmt) + + def test_batched_video_fans_out_with_audio_slice(self, fmt, tmp_path): + output = _make_video_output(batch=2) + paths = [tmp_path / f"vid_{i}" for i in range(2)] + saved = output.save(paths, format=fmt) + for i, p in enumerate(saved): + loaded = (_safetensors_load if fmt == "safetensors" else _pt_load)(p.read_bytes()) + assert loaded["video"].shape == (2, 4, 4, 3) + assert loaded["audio"].shape == (2, 32) + assert torch.equal(loaded["video"], output.video[i]) + assert torch.equal(loaded["audio"], output.audio[i]) + + +@pytest.mark.parametrize("fmt", ["safetensors", "pt"]) +class TestSaveBytes: + """``_save_bytes`` returns the same payload as :meth:`save` for the + bytes-based transport.""" + + def test_batch_index_slices_image(self, fmt): + output = _make_image_output(batch=2) + for i in range(2): + data = output._save_bytes(fmt, batch_index=i) + loaded = (_safetensors_load if fmt == "safetensors" else _pt_load)(data) + assert loaded["image"].shape == (8, 8, 3) + assert torch.equal(loaded["image"], output.image[i]) + + def test_batch_index_none_writes_unbatched_as_is(self, fmt): + output = VisualGenOutput( + request_id=1, + image=torch.zeros(4, 4, 3, dtype=torch.uint8), + ) + data = output._save_bytes(fmt, batch_index=None) + loaded = (_safetensors_load if fmt == "safetensors" else _pt_load)(data) + assert loaded["image"].shape == (4, 4, 3) + + def test_rejects_image_encoder_format(self, fmt): + # The bytes-based path is tensor-only; image/video encoders use + # the file-based ``save`` API. + output = _make_image_output(batch=1) + with pytest.raises(ValueError, match="tensor formats"): + output._save_bytes("png") + + +class TestSerializeDirect: + """Direct calls to :func:`serialize_visual_gen_output` for low-level + coverage of the rank-aware behavior.""" + + def test_unbatched_image_not_sliced(self): + output = VisualGenOutput(request_id=1, image=torch.zeros(7, 5, 3, dtype=torch.uint8)) + data = serialize_visual_gen_output(output, "safetensors", batch_index=0) + loaded = _safetensors_load(data) + # Height axis must survive; the helper must not confuse it with + # a batch axis on a rank-3 image. + assert loaded["image"].shape == (7, 5, 3) + + def test_unbatched_video_not_sliced(self): + video = torch.zeros(9, 4, 4, 3, dtype=torch.uint8) + output = VisualGenOutput(request_id=1, video=video, frame_rate=24.0) + data = serialize_visual_gen_output(output, "safetensors", batch_index=0) + loaded = _safetensors_load(data) + # Frame axis must survive on a rank-4 video. + assert loaded["video"].shape == (9, 4, 4, 3) + + +@pytest.mark.parametrize("fmt", ["safetensors", "pt"]) +class TestSaveMetadataOverrides: + """``frame_rate`` and ``audio_sample_rate`` kwargs on + :meth:`VisualGenOutput.save` and :meth:`VisualGenOutput._save_bytes` + override the corresponding fields on the output. Matches the + video-encoder path's existing override semantics so a caller who + fills in missing or stale metadata gets it into the serialized + payload as well.""" + + def _load(self, fmt, data): + return _safetensors_load(data) if fmt == "safetensors" else _pt_load(data) + + def test_save_overrides_unset_metadata(self, fmt, tmp_path): + """Output carries no rate fields; ``save`` overrides put the + right metadata into the payload.""" + video = torch.zeros(1, 2, 4, 4, 3, dtype=torch.uint8) + audio = torch.zeros(1, 2, 16, dtype=torch.float32) + output = VisualGenOutput(request_id=1, video=video, audio=audio) + target = tmp_path / "out" + saved = output.save(target, format=fmt, frame_rate=24.0, audio_sample_rate=16000) + loaded = self._load(fmt, saved.read_bytes()) + # Both serializers expose scalar metadata through the + # canonical ``load`` path: pt as native Python values, safetensors + # as 0-d tensors (which compare equal to the Python scalar). + assert loaded["frame_rate"] == 24.0 + assert loaded["audio_sample_rate"] == 16000 + if fmt == "safetensors": + # The string-keyed file header is preserved for consumers that + # use ``safe_open(...).metadata()`` instead of ``load()``. + data_bytes = output._save_bytes( + fmt, batch_index=0, frame_rate=24.0, audio_sample_rate=16000 + ) + import tempfile + + from safetensors import safe_open + + with tempfile.NamedTemporaryFile(suffix=".safetensors") as tf: + tf.write(data_bytes) + tf.flush() + with safe_open(tf.name, framework="pt") as f: + meta = f.metadata() or {} + assert meta.get("frame_rate") == "24.0" + assert meta.get("audio_sample_rate") == "16000" + + def test_save_overrides_take_precedence(self, fmt, tmp_path): + """Override values win even when the output has its own.""" + video = torch.zeros(1, 2, 4, 4, 3, dtype=torch.uint8) + output = VisualGenOutput( + request_id=1, video=video, frame_rate=12.0, audio_sample_rate=24000 + ) + target = tmp_path / "out" + saved = output.save(target, format=fmt, frame_rate=60.0, audio_sample_rate=48000) + loaded = self._load(fmt, saved.read_bytes()) + assert loaded["frame_rate"] == 60.0 + assert loaded["audio_sample_rate"] == 48000 + + def test_save_bytes_overrides(self, fmt): + """``_save_bytes`` honors the same overrides for the + ``b64_json`` transport.""" + video = torch.zeros(1, 2, 4, 4, 3, dtype=torch.uint8) + output = VisualGenOutput(request_id=1, video=video) + data = output._save_bytes(fmt, batch_index=0, frame_rate=30.0, audio_sample_rate=44100) + loaded = self._load(fmt, data) + assert loaded["frame_rate"] == 30.0 + assert loaded["audio_sample_rate"] == 44100 + + +class TestSaveFormatInference: + """When the caller omits ``format``, ``VisualGenOutput.save`` infers + the serializer from the path suffix — the same contract the image + and video encoders honor for ``.png`` / ``.mp4`` / etc.""" + + def test_safetensors_suffix_dispatches_to_tensor_path(self, tmp_path): + output = _make_image_output(batch=1) + saved = output.save(tmp_path / "out.safetensors") + assert saved.suffix == ".safetensors" + loaded = _safetensors_load(saved.read_bytes()) + # Reaching the tensor path means the image lives under the + # ``image`` key; an encoder fallback would have produced a PNG + # with no parsable safetensors structure. + assert loaded["image"].shape == (8, 8, 3) + + def test_pt_suffix_dispatches_to_tensor_path(self, tmp_path): + output = _make_image_output(batch=1) + saved = output.save(tmp_path / "out.pt") + assert saved.suffix == ".pt" + loaded = _pt_load(saved.read_bytes()) + assert loaded["image"].shape == (8, 8, 3) + + def test_list_path_inference_when_all_tensor(self, tmp_path): + output = _make_image_output(batch=2) + paths = [tmp_path / f"img_{i}.safetensors" for i in range(2)] + saved = output.save(paths) + assert all(p.suffix == ".safetensors" for p in saved) + for i, p in enumerate(saved): + loaded = _safetensors_load(p.read_bytes()) + assert torch.equal(loaded["image"], output.image[i]) + + def test_mixed_list_paths_skip_inference(self, tmp_path): + """A list of paths with mixed suffixes does not match a single + tensor format; inference returns ``None`` and the dispatch + falls through to the encoder path (no inferred-format wrong + file). The encoder behavior on mixed paths is owned by the + encoder layer and not asserted here.""" + from tensorrt_llm.visual_gen.output import _infer_format_from_path + + assert _infer_format_from_path([tmp_path / "a.safetensors", tmp_path / "b.png"]) is None + + def test_image_encoder_suffix_still_uses_encoder(self, tmp_path): + """A ``.png`` path with no explicit ``format`` keeps the + encoder path (regression guard for the inference logic).""" + output = _make_image_output(batch=1) + saved = output.save(tmp_path / "out.png") + assert saved.suffix == ".png" + # PNG magic bytes confirm the encoder path ran. + assert saved.read_bytes().startswith(b"\x89PNG\r\n\x1a\n") diff --git a/tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py b/tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py index 72d45ea89f02..e3d29f76eb87 100644 --- a/tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py +++ b/tests/unittest/_torch/visual_gen/test_trtllm_serve_e2e.py @@ -229,7 +229,7 @@ def test_health(self, server): assert resp.status_code == 200 @pytest.mark.parametrize( - "output_format,expected_content_type", + "format_,expected_content_type", [ pytest.param("avi", "video/x-msvideo", id="avi"), pytest.param( @@ -240,7 +240,7 @@ def test_health(self, server): ), ], ) - def test_t2v_sync(self, server, output_format, expected_content_type): + def test_t2v_sync(self, server, format_, expected_content_type): """Synchronous text-to-video via POST /v1/videos/generations.""" resp = requests.post( server.url_for("v1", "videos", "generations"), @@ -251,7 +251,7 @@ def test_t2v_sync(self, server, output_format, expected_content_type): "fps": 8, "num_inference_steps": 4, "seed": 42, - "output_format": output_format, + "format": format_, }, ) assert resp.status_code == 200, resp.text @@ -259,7 +259,7 @@ def test_t2v_sync(self, server, output_format, expected_content_type): assert len(resp.content) > 1000, "Video file too small" @pytest.mark.parametrize( - "output_format,expected_content_type", + "format_,expected_content_type", [ pytest.param("avi", "video/x-msvideo", id="avi"), pytest.param( @@ -270,7 +270,7 @@ def test_t2v_sync(self, server, output_format, expected_content_type): ), ], ) - def test_t2v_async_lifecycle(self, server, output_format, expected_content_type): + def test_t2v_async_lifecycle(self, server, format_, expected_content_type): """Async video generation: create job → poll → download → delete.""" base = server.url_for("v1", "videos") @@ -284,7 +284,7 @@ def test_t2v_async_lifecycle(self, server, output_format, expected_content_type) "fps": 8, "num_inference_steps": 4, "seed": 42, - "output_format": output_format, + "format": format_, }, ) assert create_resp.status_code == 202, create_resp.text @@ -349,7 +349,7 @@ def test_health(self, server): assert resp.status_code == 200 @pytest.mark.parametrize( - "output_format,expected_content_type", + "format_,expected_content_type", [ pytest.param("avi", "video/x-msvideo", id="avi"), pytest.param( @@ -360,7 +360,7 @@ def test_health(self, server): ), ], ) - def test_ti2v_sync(self, server, output_format, expected_content_type): + def test_ti2v_sync(self, server, format_, expected_content_type): """Synchronous image-to-video via multipart POST /v1/videos/generations.""" with open(_REF_IMAGE_PATH, "rb") as f: resp = requests.post( @@ -372,7 +372,7 @@ def test_ti2v_sync(self, server, output_format, expected_content_type): "fps": "8", "num_inference_steps": "4", "seed": "42", - "output_format": output_format, + "format": format_, }, files={ "input_reference": ("cat_piano.png", f, "image/png"), @@ -383,7 +383,7 @@ def test_ti2v_sync(self, server, output_format, expected_content_type): assert len(resp.content) > 1000, "Video file too small" @pytest.mark.parametrize( - "output_format,expected_content_type", + "format_,expected_content_type", [ pytest.param("avi", "video/x-msvideo", id="avi"), pytest.param( @@ -394,7 +394,7 @@ def test_ti2v_sync(self, server, output_format, expected_content_type): ), ], ) - def test_ti2v_async_lifecycle(self, server, output_format, expected_content_type): + def test_ti2v_async_lifecycle(self, server, format_, expected_content_type): """Async i2v: create job with image → poll → download → delete.""" base = server.url_for("v1", "videos") @@ -409,7 +409,7 @@ def test_ti2v_async_lifecycle(self, server, output_format, expected_content_type "fps": "8", "num_inference_steps": "4", "seed": "42", - "output_format": output_format, + "format": format_, }, files={ "input_reference": ("cat_piano.png", f, "image/png"), diff --git a/tests/unittest/_torch/visual_gen/test_trtllm_serve_endpoints.py b/tests/unittest/_torch/visual_gen/test_trtllm_serve_endpoints.py index 2b720da90185..105f9ad90307 100644 --- a/tests/unittest/_torch/visual_gen/test_trtllm_serve_endpoints.py +++ b/tests/unittest/_torch/visual_gen/test_trtllm_serve_endpoints.py @@ -36,6 +36,30 @@ # --------------------------------------------------------------------------- +def _assert_llm_envelope( + body: dict, + *, + code: int, + err_type: str = "BadRequestError", + message_contains: Optional[str] = None, +) -> None: + """Assert *body* is the visual-gen LLM-style error envelope. + + The envelope's wire shape is ``{"object": "error", "message": str, + "type": str, "code": int}`` with optional ``"param": str | None``. + ``object`` and ``param`` are returned by Pydantic's + ``ErrorResponse.model_dump`` and are stable across all visual-gen + error paths. + """ + assert set(body.keys()) == {"object", "message", "type", "param", "code"}, body + assert body["object"] == "error" + assert body["type"] == err_type + assert body["code"] == code + assert isinstance(body["message"], str) and body["message"] + if message_contains is not None: + assert message_contains in body["message"], body["message"] + + def _make_dummy_image_tensor(height: int = 64, width: int = 64) -> torch.Tensor: """Create a small dummy uint8 image tensor (H, W, C).""" return torch.randint(0, 256, (height, width, 3), dtype=torch.uint8) @@ -105,18 +129,46 @@ def __init__( audio_output: Optional[torch.Tensor] = None, should_fail: bool = False, batch_aware: bool = True, + validation_error: Optional[ValueError] = None, ): + from types import SimpleNamespace + self._image = image_output self._video = video_output self._audio = audio_output self._should_fail = should_fail self._batch_aware = batch_aware + self._validation_error = validation_error self._healthy = True self._req_counter = 0 # Captured arguments of the most recent generate / generate_async call, # used by tests to assert forwarded VisualGenParams fields. self.last_inputs = None self.last_params = None + # Stand-in for the coordinator-side executor proxy. The async video + # route reads ``default_generation_params`` / ``extra_param_specs`` + # directly off this attribute when running synchronous pre-flight + # validation. ``default_generation_params`` declares the universal + # fields the mock pipeline accepts so the validator doesn't + # reject legitimate width/height/num_frames/... requests; + # ``extra_param_specs`` lists a single known key so tests can + # exercise both the accept-known and reject-unknown paths. + from tensorrt_llm._torch.visual_gen.pipeline import ExtraParamSchema + + self.executor = SimpleNamespace( + default_generation_params={ + "height": 64, + "width": 64, + "num_inference_steps": 20, + "guidance_scale": 5.0, + "max_sequence_length": 64, + "num_frames": 8, + "frame_rate": 8.0, + }, + extra_param_specs={ + "stg_scale": ExtraParamSchema(type="float", default=1.0), + }, + ) def _maybe_batch(self, tensor, n): """Replicate a single tensor along a new leading batch dimension.""" @@ -129,6 +181,8 @@ def _maybe_batch(self, tensor, n): def generate(self, inputs=None, params=None) -> VisualGenOutput: self.last_inputs = inputs self.last_params = params + if self._validation_error is not None: + raise self._validation_error if self._should_fail: raise RuntimeError("Generation intentionally failed") n = getattr(params, "num_images_per_prompt", 1) if params else 1 @@ -143,6 +197,8 @@ def generate(self, inputs=None, params=None) -> VisualGenOutput: def generate_async(self, inputs=None, params=None) -> "MockVisualGenResult": self.last_inputs = inputs self.last_params = params + if self._validation_error is not None: + raise self._validation_error n = getattr(params, "num_images_per_prompt", 1) if params else 1 return MockVisualGenResult( request_id=self._next_request_id(), @@ -165,6 +221,19 @@ def default_params(self): return VisualGenParams() + @property + def extra_param_specs(self): + """Stand-in for VisualGen.extra_param_specs — empty by default so + every request ``extra_params`` key reaches the executor as + ``unknown_extra_param`` (matches a pipeline with no model-specific + knobs declared, like Flux or Wan 2.1).""" + return {} + + @property + def model(self): + """Stand-in for VisualGen.model — used by warn-on-set logic.""" + return "test-model" + def _check_health(self) -> bool: return self._healthy @@ -334,6 +403,7 @@ def _dummy_save_encoded_video(video, audio, output_path, frame_rate, audio_sampl # ========================================================================= +@pytest.mark.threadleak(enabled=False) # FileResponse spawns AnyIO worker threads class TestImageGeneration: def test_basic_image_generation_b64(self, image_client): resp = image_client.post( @@ -382,7 +452,12 @@ def test_image_generation_with_optional_params(self, image_client): assert params.guidance_scale == 7.5 assert params.negative_prompt == "blurry" - def test_image_generation_url_format_not_supported(self, image_client): + def test_image_generation_url_returns_fetchable_urls(self, image_client): + """``response_format='url'`` writes each generated image to + media storage and surfaces a server-relative HTTP URL pointing + at ``GET /v1/images/{id}/content?i=N``. The URL fetches the + image bytes back through the API instead of leaking the + on-disk path.""" resp = image_client.post( "/v1/images/generations", json={ @@ -390,7 +465,61 @@ def test_image_generation_url_format_not_supported(self, image_client): "response_format": "url", }, ) - assert resp.status_code == 501 + assert resp.status_code == 200 + body = resp.json() + assert len(body["data"]) >= 1 + url = body["data"][0]["url"] + # URL is an HTTP URL through the API content endpoint. + assert "/v1/images/" in url and "/content" in url + # Fetch via the same client to verify it works. + path = url.split("//", 1)[-1].split("/", 1)[1] + content = image_client.get("/" + path) + assert content.status_code == 200 + # PNG bytes start with the standard magic header. + assert content.content.startswith(b"\x89PNG\r\n\x1a\n") + assert content.headers["content-type"] == "image/png" + + def test_image_generation_safetensors_b64(self, image_client): + """Tensor formats return base64-encoded raw bytes; loading the + payload yields the engine tensors back.""" + from safetensors.torch import load as load_safetensors + + resp = image_client.post( + "/v1/images/generations", + json={ + "prompt": "Tensor cat", + "response_format": "b64_json", + "format": "safetensors", + }, + ) + assert resp.status_code == 200 + body = resp.json() + assert len(body["data"]) == 1 + b64 = body["data"][0]["b64_json"] + loaded = load_safetensors(base64.b64decode(b64)) + assert "image" in loaded + + def test_image_generation_pt_url(self, image_client): + """Tensor formats under ``response_format='url'`` write each + per-item payload to media storage and surface a fetchable + HTTP URL through the image content endpoint.""" + resp = image_client.post( + "/v1/images/generations", + json={ + "prompt": "Tensor dog", + "response_format": "url", + "format": "pt", + }, + ) + assert resp.status_code == 200 + url = resp.json()["data"][0]["url"] + assert "/v1/images/" in url and "/content" in url + path = url.split("//", 1)[-1].split("/", 1)[1] + content = image_client.get("/" + path) + assert content.status_code == 200 + assert content.headers["content-type"] == "application/octet-stream" + loaded = torch.load(BytesIO(content.content), weights_only=True) + assert "image" in loaded def test_image_generation_auto_size(self, image_client): resp = image_client.post( @@ -404,6 +533,8 @@ def test_image_generation_auto_size(self, image_client): assert resp.status_code == 200 def test_image_generation_failure(self, failing_client): + """Engine-side ``RuntimeError`` (non-validation) surfaces as HTTP 500; + the LLM envelope carries the error message.""" resp = failing_client.post( "/v1/images/generations", json={ @@ -411,10 +542,12 @@ def test_image_generation_failure(self, failing_client): "response_format": "b64_json", }, ) - assert resp.status_code == 400 + assert resp.status_code == 500 + _assert_llm_envelope(resp.json(), code=500, err_type="InternalServerError") def test_image_generation_invalid_size(self, image_client): - """Invalid size triggers RequestValidationError → custom handler → 400.""" + """Invalid size triggers a Pydantic ``RequestValidationError``; + the visual-gen-scoped handler emits the LLM-style 422 envelope.""" resp = image_client.post( "/v1/images/generations", json={ @@ -423,7 +556,8 @@ def test_image_generation_invalid_size(self, image_client): "size": "invalid", }, ) - assert resp.status_code == 400 + assert resp.status_code == 422 + _assert_llm_envelope(resp.json(), code=422, message_contains="size") def test_image_generation_null_output(self, tmp_path): """Generator returns VisualGenOutput with image=None.""" @@ -465,12 +599,15 @@ def test_image_generation_hd_quality(self, image_client): assert resp.status_code == 200 def test_missing_prompt_image_generation(self, image_client): - """Missing required field → RequestValidationError → custom handler → 400.""" + """Missing required field surfaces as a Pydantic + ``RequestValidationError`` and the visual-gen-scoped handler + returns the LLM-style 422 envelope.""" resp = image_client.post( "/v1/images/generations", json={}, ) - assert resp.status_code == 400 + assert resp.status_code == 422 + _assert_llm_envelope(resp.json(), code=422, message_contains="prompt") def test_image_generation_b64_no_save_image_no_disk_write(self, image_client, tmp_path): """Regression guard for NVBug 6064029. @@ -554,7 +691,7 @@ class TestImageEdit: """ def test_image_edit_returns_not_implemented(self, image_client): - """Valid request body still short-circuits to 501 NotImplemented.""" + """Valid request body short-circuits to 501 NotImplemented.""" b64_img = _b64_white_png_1x1() resp = image_client.post( "/v1/images/edits", @@ -569,16 +706,14 @@ def test_image_edit_returns_not_implemented(self, image_client): assert body.get("type") == "NotImplementedError" assert "not supported" in body.get("message", "").lower() - def test_missing_image_for_edit(self, image_client): - """Missing required field is rejected by FastAPI request validation - (400) before the 501 short-circuit, so this contract is unchanged.""" - resp = image_client.post( - "/v1/images/edits", - json={ - "prompt": "Edit without image", - }, - ) - assert resp.status_code == 400 + def test_image_edit_no_body_returns_not_implemented(self, image_client): + """The route doesn't parse a typed body; any incoming request still + gets 501, including ones that would have failed schema validation + before. Restore typed-body coverage when an edit pipeline lands.""" + resp = image_client.post("/v1/images/edits", json={"prompt": "Edit without image"}) + assert resp.status_code == 501 + body = resp.json() + assert body.get("type") == "NotImplementedError" # ========================================================================= @@ -663,20 +798,21 @@ def test_sync_video_generation_with_params(self, video_client): assert params.frame_rate == 8 assert params.num_frames == int(2.0 * 8) - def test_sync_video_generation_multipart(self, video_client): - # Use files={} with a dummy file to ensure multipart/form-data - dummy_file = BytesIO(b"") - resp = video_client.post( - "/v1/videos/generations", - data={ - "prompt": "Mountain sunrise", - "size": "64x64", - "seconds": "1.0", - "fps": "8", - }, - files={"_dummy": ("dummy", dummy_file, "application/octet-stream")}, - ) - # The server will parse fields; _dummy is ignored since it's not "input_reference" + def test_sync_video_generation_multipart(self, video_client, tmp_path): + """Multipart sync request with a real ``input_reference`` file.""" + ref_path = tmp_path / "ref.png" + Image.new("RGB", (4, 4), (64, 64, 64)).save(str(ref_path)) + with open(ref_path, "rb") as f: + resp = video_client.post( + "/v1/videos/generations", + data={ + "prompt": "Mountain sunrise", + "size": "64x64", + "seconds": "1.0", + "fps": "8", + }, + files={"input_reference": ("ref.png", f, "image/png")}, + ) assert resp.status_code == 200 assert len(resp.content) > 0 @@ -741,26 +877,48 @@ def test_sync_video_unsupported_content_type(self, video_client): assert resp.status_code == 400 def test_sync_video_missing_prompt_json(self, video_client): - """Missing required prompt → Pydantic ValidationError → 400.""" + """Missing required ``prompt`` surfaces the visual-gen 422 envelope.""" resp = video_client.post( "/v1/videos/generations", json={"size": "64x64"}, headers={"content-type": "application/json"}, ) - assert resp.status_code == 400 + assert resp.status_code == 422 + _assert_llm_envelope(resp.json(), code=422, message_contains="prompt") def test_sync_video_missing_prompt_multipart(self, video_client): - """Missing prompt in multipart form → ValueError → 400.""" + """Multipart body with a missing required field surfaces the + same LLM envelope as JSON so the wire contract is identical.""" dummy_file = BytesIO(b"") resp = video_client.post( "/v1/videos/generations", data={"size": "64x64"}, files={"_dummy": ("dummy", dummy_file, "application/octet-stream")}, ) - assert resp.status_code == 400 + assert resp.status_code == 422 + _assert_llm_envelope(resp.json(), code=422) - def test_sync_video_batch_n2(self, video_client): - """Sync video with n=2 should succeed and return the first video.""" + def test_sync_video_multipart_rejects_unknown_field(self, video_client): + """Strict multipart parsing rejects any form field that is not + on :class:`VideoGenerationRequest` with the same 422 envelope as + the JSON path.""" + dummy_file = BytesIO(b"") + resp = video_client.post( + "/v1/videos/generations", + data={ + "prompt": "Strict multipart", + "size": "64x64", + "seconds": "1.0", + "fps": "8", + "output_format": "mp4", + }, + files={"_dummy": ("dummy", dummy_file, "application/octet-stream")}, + ) + assert resp.status_code == 422 + _assert_llm_envelope(resp.json(), code=422, message_contains="output_format") + + def test_sync_video_rejects_top_level_n(self, video_client): + """Sync video has no top-level ``n``; it's rejected with 422.""" resp = video_client.post( "/v1/videos/generations", json={ @@ -772,8 +930,8 @@ def test_sync_video_batch_n2(self, video_client): }, headers={"content-type": "application/json"}, ) - assert resp.status_code == 200 - assert len(resp.content) > 0 + assert resp.status_code == 422 + _assert_llm_envelope(resp.json(), code=422) # ========================================================================= @@ -817,51 +975,80 @@ def test_async_video_job_metadata_fields(self, video_client): assert data["fps"] == 12 assert data["size"] == "64x64" - def test_async_video_multipart(self, video_client): - """Multipart encoding requires a file field to trigger the correct content-type.""" - dummy_file = BytesIO(b"") + def test_async_video_multipart(self, video_client, tmp_path): + """Multipart async request with a real ``input_reference`` file.""" + ref_path = tmp_path / "ref.png" + Image.new("RGB", (4, 4), (16, 16, 16)).save(str(ref_path)) + with open(ref_path, "rb") as f: + resp = video_client.post( + "/v1/videos", + data={ + "prompt": "A sunset", + "size": "64x64", + "seconds": "1.0", + "fps": "8", + }, + files={"input_reference": ("ref.png", f, "image/png")}, + ) + assert resp.status_code == 202 + + def test_async_video_rejects_top_level_n(self, video_client): + """Video has no top-level ``n``; it's rejected with 422 by ``extra=forbid``.""" resp = video_client.post( "/v1/videos", - data={ - "prompt": "A sunset", + json={ + "prompt": "Batch fireworks", "size": "64x64", - "seconds": "1.0", - "fps": "8", + "seconds": 1.0, + "fps": 8, + "n": 2, }, - files={"_dummy": ("dummy", dummy_file, "application/octet-stream")}, + headers={"content-type": "application/json"}, ) - assert resp.status_code == 202 + assert resp.status_code == 422 + _assert_llm_envelope(resp.json(), code=422) - def test_async_video_invalid_seconds(self, video_client): - """Seconds must be between 1.0 and 16.0. Validation error → 400.""" + def test_async_video_rejects_top_level_guidance_rescale(self, video_client): + """``guidance_rescale`` is per-model; must travel via ``extra_params``.""" resp = video_client.post( "/v1/videos", json={ - "prompt": "Too short", - "seconds": 0.1, + "prompt": "Bad knob", + "seconds": 1.0, "size": "64x64", "fps": 8, + "guidance_rescale": 0.7, }, headers={"content-type": "application/json"}, ) - assert resp.status_code == 400 + assert resp.status_code == 422 + _assert_llm_envelope(resp.json(), code=422) - def test_async_video_invalid_fps(self, video_client): - """Fps must be between 8 and 60. Validation error → 400.""" + def test_async_video_rejects_output_format(self, video_client): + """``output_format`` has been renamed to ``format``.""" resp = video_client.post( "/v1/videos", json={ - "prompt": "Bad fps", + "prompt": "Bad name", "seconds": 1.0, - "fps": 2, "size": "64x64", + "fps": 8, + "output_format": "mp4", }, headers={"content-type": "application/json"}, ) - assert resp.status_code == 400 - - def test_async_video_forwards_params(self, video_client): - """Ensure async video endpoint forwards VisualGenParams to generate_async.""" + assert resp.status_code == 422 + _assert_llm_envelope(resp.json(), code=422) + + def test_async_video_accepts_request_with_params(self, video_client): + """The async ``/v1/videos`` route accepts the full request shape and + returns 202 with a queued job. Per-field forwarding is asserted + only against the *sync* routes — the async path deep-copies the + request before enqueuing and the background task runs out-of-order + with the test, so ``mock_gen.last_params`` is not a reliable + capture point for merge-semantics here. Direct conversion-helper + tests cover the field-by-field overlay instead. + """ resp = video_client.post( "/v1/videos", json={ @@ -877,40 +1064,22 @@ def test_async_video_forwards_params(self, video_client): headers={"content-type": "application/json"}, ) assert resp.status_code == 202 - video_id = resp.json()["id"] - - # The background task calls generate_async lazily — drive the event - # loop via status polling until the job completes. - import time as _time - - deadline = _time.time() + 5 - while _time.time() < deadline: - meta = video_client.get(f"/v1/videos/{video_id}").json() - if meta.get("status") in ("completed", "failed"): - break - _time.sleep(0.05) + data = resp.json() + assert data["status"] == "queued" + assert data["object"] == "video" + assert data["prompt"] == "Rainy street" + assert data["id"].startswith("video_") - params = video_client.mock_gen.last_params - assert video_client.mock_gen.last_inputs == "Rainy street" - assert params.width == 128 - assert params.height == 64 - assert params.num_inference_steps == 12 - assert params.guidance_scale == 6.0 - assert params.seed == 7 - assert params.negative_prompt == "noise" - assert params.frame_rate == 10 - assert params.num_frames == int(2.0 * 10) - - def test_async_video_batch_n2(self, video_client): - """Async video with n=2 should accept the request and return 202.""" + def test_async_video_accepts_extra_params(self, video_client): + """Per-model overflow travels through ``extra_params``.""" resp = video_client.post( "/v1/videos", json={ - "prompt": "Batch fireworks", + "prompt": "Stylized fireworks", "size": "64x64", "seconds": 1.0, "fps": 8, - "n": 2, + "extra_params": {"stg_scale": 1.5}, }, headers={"content-type": "application/json"}, ) @@ -1150,3 +1319,798 @@ def test_async_video_null_output_updates_job_status(self, tmp_path): assert "output.video is None" in data["error"] os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None) + + +# ========================================================================= +# Route-level engine-validation-error handling +# ========================================================================= + + +def _make_validation_error(param: str = "stg_sclae"): + """Build the kind of stock ``ValueError`` ``validate_visual_gen_params`` + raises when extra_params contains an unknown key. Tests inject this + onto the mock so the routes' ``except ValueError`` arm fires the same + way it would in production.""" + return ValueError( + f"Parameter validation failed:\n - Unknown extra_params ['{param}']. Supported: []" + ) + + +class TestRouteEngineValidationError: + """When the engine raises ``ValueError`` (request-shape problem), the + image and sync-video routes return HTTP 400 with the LLM envelope + built from the exception message. The async-video route runs the + same check synchronously via ``validate_visual_gen_params`` so an + unknown ``extra_params`` key surfaces as 400 immediately instead of + becoming a queued 202 whose background task later fails.""" + + def test_image_route_renders_validation_error_at_400(self, tmp_path): + os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path) + try: + gen = MockVisualGen( + image_output=_make_dummy_image_tensor(), + validation_error=_make_validation_error(), + ) + client = _create_server(gen) + resp = client.post( + "/v1/images/generations", + json={ + "prompt": "trigger validation error", + "response_format": "b64_json", + "extra_params": {"stg_sclae": 1.0}, + }, + ) + assert resp.status_code == 400 + _assert_llm_envelope( + resp.json(), + code=400, + message_contains="stg_sclae", + ) + finally: + os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None) + + def test_sync_video_route_renders_validation_error_at_400(self, tmp_path): + os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path) + try: + gen = MockVisualGen( + video_output=_make_dummy_video_tensor(), + validation_error=_make_validation_error(), + ) + client = _create_server(gen) + resp = client.post( + "/v1/videos/generations", + json={ + "prompt": "trigger validation error", + "size": "64x64", + "seconds": 1.0, + "fps": 8, + "extra_params": {"stg_sclae": 1.0}, + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 400 + _assert_llm_envelope( + resp.json(), + code=400, + message_contains="stg_sclae", + ) + finally: + os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None) + + def test_image_route_serialization_value_error_returns_500(self, tmp_path, monkeypatch): + """Server-side serialization failures map to 500, not 400. + + ``infer_batch_size`` / ``serialize_visual_gen_output`` raise + ``ValueError`` for conditions on the server's own output + (no media tensor, inconsistent multi-modal batch). The image + route must render those as 500 — the client's request was + valid; the server failed to serialize its own output. + """ + os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path) + try: + gen = MockVisualGen(image_output=_make_dummy_image_tensor()) + + def _raise_server_side(*args, **kwargs): + raise ValueError("Cannot infer batch size: carries no media tensor.") + + # Force the tensor-format branch to hit a server-side ValueError + # in the serialization region (outside the pre-generation try). + monkeypatch.setattr( + "tensorrt_llm.media.tensor_payload.infer_batch_size", + _raise_server_side, + ) + client = _create_server(gen) + resp = client.post( + "/v1/images/generations", + json={ + "prompt": "trigger serialization failure", + "response_format": "b64_json", + "format": "safetensors", + }, + ) + assert resp.status_code == 500 + finally: + os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None) + + def test_async_video_route_rejects_validation_error_synchronously(self, tmp_path): + """``/v1/videos`` calls ``validate_visual_gen_params`` against the + mock's executor metadata before queuing; the mock's + ``extra_param_specs={}`` causes any unknown extra to be rejected + with a stock ``ValueError`` which the route's ``except ValueError`` + arm renders as HTTP 400.""" + os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = str(tmp_path) + try: + gen = MockVisualGen(video_output=_make_dummy_video_tensor()) + client = _create_server(gen) + resp = client.post( + "/v1/videos", + json={ + "prompt": "trigger validation error", + "size": "64x64", + "seconds": 1.0, + "fps": 8, + "extra_params": {"stg_sclae": 1.0}, + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 400 + _assert_llm_envelope( + resp.json(), + code=400, + message_contains="stg_sclae", + ) + finally: + os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None) + + +# ========================================================================= +# Non-visual-gen routes keep FastAPI's default validation response +# ========================================================================= + + +class TestNonVisualGenValidationResponse: + """Validation failures on non-visual-gen roles use the shared + ``OpenAIServer`` response shape (HTTP 400 + ``{"error": ...}``) + that existing integration coverage and clients expect (e.g. + ``test_malformed_json_request``). Only the visual-gen role swaps + in the LLM envelope. + + The assertion is checked at the handler-closure level: rebuild + the exact dispatch installed in :meth:`OpenAIServer.__init__` + against a minimal FastAPI app so the assertion stays narrow and + the test doesn't need to spin up a full LLM-role server.""" + + def _build_app_with_dispatch(self, role): + """Return a FastAPI app wired with the production handler + dispatch, where ``role`` controls the branch the handler takes + on a ``RequestValidationError``.""" + from fastapi import FastAPI + from fastapi.exceptions import RequestValidationError + from fastapi.responses import JSONResponse + from pydantic import BaseModel + + app = FastAPI() + + class _Body(BaseModel): + messages: list + + @app.post("/route") + async def _route(body: _Body): + return {"ok": True} + + @app.exception_handler(RequestValidationError) + async def _handler(_, exc): + if role == "VISUAL_GEN": + return _llm_envelope_branch(exc) + return JSONResponse(status_code=400, content={"error": str(exc)}) + + # Mirror :meth:`OpenAIServer._create_visual_gen_validation_error_response` + # inline so the test does not depend on instance state. + def _llm_envelope_branch(exc): + from http import HTTPStatus + + from tensorrt_llm.serve.openai_protocol import ErrorResponse + + error = ErrorResponse( + message="Validation failed", + type="BadRequestError", + code=HTTPStatus.UNPROCESSABLE_ENTITY.value, + ) + return JSONResponse( + content=error.model_dump(), + status_code=HTTPStatus.UNPROCESSABLE_ENTITY.value, + ) + + return app + + def test_non_visual_gen_role_returns_shared_400_error_body(self): + """Non-visual-gen roles return HTTP 400 with the shared + ``{"error": str(exc)}`` body that ``test_malformed_json_request`` + and existing clients depend on.""" + client = TestClient(self._build_app_with_dispatch(role="CONTEXT")) + resp = client.post("/route", json={"not_messages": []}) + assert resp.status_code == 400 + body = resp.json() + assert "error" in body + assert isinstance(body["error"], str) + # The visual-gen LLM envelope must not leak into non-VG paths. + assert "object" not in body + assert "type" not in body + assert "code" not in body + + def test_visual_gen_role_uses_llm_envelope(self): + client = TestClient(self._build_app_with_dispatch(role="VISUAL_GEN")) + resp = client.post("/route", json={"not_messages": []}) + assert resp.status_code == 422 + body = resp.json() + assert body["type"] == "BadRequestError" + assert body["code"] == 422 + assert "message" in body + + +# ========================================================================= +# Tensor-format response coverage on the video routes +# ========================================================================= + + +@pytest.mark.threadleak(enabled=False) # FileResponse spawns AnyIO worker threads +class TestVideoTensorResponse: + """The sync route emits tensor payloads as a single file under + ``response_format='url'`` and as base64-encoded bytes under + ``response_format='b64_json'``. The async route persists the + payload to media storage; ``GET /v1/videos/{id}/content`` serves + the file with ``application/octet-stream``.""" + + def _post_sync(self, video_client, fmt: str, response_format: str): + return video_client.post( + "/v1/videos/generations", + json={ + "prompt": f"tensor video {fmt}", + "size": "32x32", + "seconds": 1.0, + "fps": 8, + "format": fmt, + "response_format": response_format, + }, + headers={"content-type": "application/json"}, + ) + + @pytest.mark.parametrize("fmt", ["safetensors", "pt"]) + def test_sync_tensor_url_returns_file_with_correct_suffix(self, video_audio_client, fmt): + resp = self._post_sync(video_audio_client, fmt, "url") + assert resp.status_code == 200 + ext = f".{fmt}" + # The content-disposition header carries the on-disk filename. + disp = resp.headers.get("content-disposition", "") + assert ext in disp, disp + # And the payload itself round-trips. + if fmt == "safetensors": + from safetensors.torch import load as load_safetensors + + loaded = load_safetensors(resp.content) + else: + loaded = torch.load(BytesIO(resp.content), weights_only=True) + assert "video" in loaded + + @pytest.mark.parametrize("fmt", ["safetensors", "pt"]) + def test_sync_tensor_b64_returns_decodable_payload(self, video_audio_client, fmt): + resp = self._post_sync(video_audio_client, fmt, "b64_json") + assert resp.status_code == 200 + data = resp.json() + assert data["format"] == fmt + assert "b64_json" in data + raw = base64.b64decode(data["b64_json"]) + if fmt == "safetensors": + from safetensors.torch import load as load_safetensors + + loaded = load_safetensors(raw) + else: + loaded = torch.load(BytesIO(raw), weights_only=True) + assert "video" in loaded + + @pytest.mark.parametrize("fmt", ["safetensors", "pt"]) + def test_async_tensor_persists_and_serves(self, video_audio_client, fmt, tmp_path): + import time as _time + + client = video_audio_client + resp = client.post( + "/v1/videos", + json={ + "prompt": f"async tensor {fmt}", + "size": "32x32", + "seconds": 1.0, + "fps": 8, + "format": fmt, + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 202 + video_id = resp.json()["id"] + + # Drive the background task to completion via polling. + deadline = _time.time() + 5 + while _time.time() < deadline: + status = client.get(f"/v1/videos/{video_id}").json().get("status") + if status in ("completed", "failed"): + break + _time.sleep(0.05) + + content = client.get(f"/v1/videos/{video_id}/content") + assert content.status_code == 200 + # The server returns ``application/octet-stream`` for tensor payloads. + assert content.headers["content-type"] == "application/octet-stream" + if fmt == "safetensors": + from safetensors.torch import load as load_safetensors + + loaded = load_safetensors(content.content) + else: + loaded = torch.load(BytesIO(content.content), weights_only=True) + assert "video" in loaded + + +@pytest.mark.threadleak(enabled=False) # FileResponse spawns AnyIO worker threads +class TestVideoEncoderB64Response: + """The sync video route's encoder branch (``mp4``/``avi``/``auto``) + honors ``response_format='b64_json'`` by base64-encoding the + encoded video bytes; ``response_format='url'`` keeps the + ``FileResponse`` download.""" + + def test_sync_encoder_b64_json_returns_base64_payload(self, video_client): + resp = video_client.post( + "/v1/videos/generations", + json={ + "prompt": "encoded b64", + "size": "32x32", + "seconds": 1.0, + "fps": 8, + "format": "avi", + "response_format": "b64_json", + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["format"] in {"mp4", "avi"} + assert "b64_json" in body + raw = base64.b64decode(body["b64_json"]) + # Non-empty encoded bytes — exact format verification is the + # encoder layer's domain. + assert len(raw) > 0 + + def test_sync_encoder_url_keeps_file_response(self, video_client): + resp = video_client.post( + "/v1/videos/generations", + json={ + "prompt": "encoded url", + "size": "32x32", + "seconds": 1.0, + "fps": 8, + "format": "avi", + "response_format": "url", + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 200 + # FileResponse for an AVI carries ``video/x-msvideo``. + assert resp.headers["content-type"] == "video/x-msvideo" + + +class TestVideoTimingValidation: + """Numeric optionals on ``VideoGenerationRequest`` reject zero / + negative values so divisions and frame-count math downstream can + trust the value.""" + + @pytest.mark.parametrize( + "field,value", + [ + ("fps", 0), + ("frame_rate", -1), + ("num_frames", 0), + ("num_frames", -3), + ("seconds", 0), + ("seconds", -2.5), + ], + ) + def test_non_positive_timing_field_rejected(self, video_client, field, value): + resp = video_client.post( + "/v1/videos/generations", + json={ + "prompt": "bad timing", + "size": "32x32", + field: value, + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 422 + _assert_llm_envelope(resp.json(), code=422, message_contains=field) + + +class TestImageResponseFormatMetadata: + """``ImageGenerationResponse.output_format`` reflects the + requested encoding so clients that introspect the response know + how to decode the bytes / read the URL.""" + + @pytest.mark.parametrize( + "fmt", + ["png", "webp", "jpeg", "safetensors", "pt"], + ) + def test_response_carries_requested_format(self, image_client, fmt): + resp = image_client.post( + "/v1/images/generations", + json={ + "prompt": f"metadata for {fmt}", + "response_format": "b64_json", + "format": fmt, + }, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["output_format"] == fmt + + +@pytest.mark.threadleak(enabled=False) # FileResponse spawns AnyIO worker threads +class TestVideoZeroFrameDerivationRejected: + """``seconds * frame_rate`` that floors to zero frames must be + rejected with HTTP 400 + LLM envelope rather than reaching the + encoder with a 0-frame video.""" + + def test_subsecond_seconds_below_one_frame_returns_400(self, video_client): + resp = video_client.post( + "/v1/videos/generations", + json={ + "prompt": "way too short", + "size": "32x32", + "seconds": 0.01, + "fps": 8, + }, + headers={"content-type": "application/json"}, + ) + # int(0.01 * 8) == 0 — conversion raises ValueError → 400. + assert resp.status_code == 400 + _assert_llm_envelope( + resp.json(), + code=400, + message_contains="Derived frame count", + ) + + def test_seconds_without_frame_rate_returns_400(self, video_client): + """``seconds`` set but neither the request nor the pipeline default + declares a ``frame_rate``: the parser must reject the request with + HTTP 400 instead of silently dropping the duration and returning the + pipeline's default ``num_frames``.""" + resp = video_client.post( + "/v1/videos/generations", + json={ + "prompt": "duration without fps", + "size": "32x32", + "seconds": 1.0, + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 400 + _assert_llm_envelope( + resp.json(), + code=400, + message_contains="frame_rate", + ) + + def test_explicit_num_frames_one_is_accepted(self, video_client): + """The caller can bypass the derivation by passing ``num_frames`` + directly; the request must succeed.""" + resp = video_client.post( + "/v1/videos/generations", + json={ + "prompt": "explicit single frame", + "size": "32x32", + "num_frames": 1, + "fps": 8, + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 200 + + +class TestImageBatchCap: + """``ImageGenerationRequest.n`` is capped at 10 to bound resource + usage. ``n=10`` is accepted; ``n=11`` and ``n=100000`` are + rejected at the schema layer with HTTP 422 + LLM envelope.""" + + def test_n_equal_to_ten_accepted(self, image_client): + resp = image_client.post( + "/v1/images/generations", + json={ + "prompt": "ten images", + "response_format": "b64_json", + "size": "32x32", + "n": 10, + }, + ) + assert resp.status_code == 200 + assert len(resp.json()["data"]) == 10 + + @pytest.mark.parametrize("n", [11, 100000]) + def test_n_above_cap_rejected(self, image_client, n): + resp = image_client.post( + "/v1/images/generations", + json={ + "prompt": "too many", + "response_format": "b64_json", + "size": "32x32", + "n": n, + }, + ) + assert resp.status_code == 422 + _assert_llm_envelope(resp.json(), code=422, message_contains="n") + + +@pytest.mark.threadleak(enabled=False) # FileResponse spawns AnyIO worker threads +class TestVideoFrameBudgetCap: + """Upper bounds keep unbounded work / memory requests from reaching + the engine. The defaults (a minute of video at 120 fps) are + generous enough for common workloads; clients hitting the cap can + raise it at deployment time.""" + + @pytest.mark.parametrize( + "field,value,boundary", + [ + ("num_frames", 7200, "accepted"), + ("num_frames", 7201, "rejected"), + ("num_frames", 1_000_000, "rejected"), + ("seconds", 60.0, "accepted"), + ("seconds", 60.1, "rejected"), + ("seconds", 1.0e9, "rejected"), + ("fps", 120.0, "accepted"), + ("fps", 120.1, "rejected"), + ("fps", 1.0e6, "rejected"), + ], + ) + def test_frame_budget_bounds(self, video_client, field, value, boundary): + payload = { + "prompt": "boundary", + "size": "32x32", + } + if field != "num_frames": + # Pair seconds/fps with a sane partner to avoid the + # derived-zero-frames check; pass num_frames otherwise. + payload.update({"seconds": 1.0, "fps": 8}) + payload[field] = value + resp = video_client.post( + "/v1/videos/generations", + json=payload, + headers={"content-type": "application/json"}, + ) + if boundary == "accepted": + # The schema accepts the value at the boundary. The + # downstream pipeline may still 200 or 500 depending on + # the mock's tensor shape; the relevant assertion is that + # the request did not fall into the schema-rejection path. + assert resp.status_code != 422, resp.text + else: + assert resp.status_code == 422 + _assert_llm_envelope(resp.json(), code=422, message_contains=field) + + +class TestVideoJobFractionalFps: + """``VideoJob.fps`` is a float so cinematic frame rates like + 23.976 / 29.97 round-trip through the queued metadata instead of + being truncated to int.""" + + @pytest.mark.parametrize("rate", [23.976, 29.97, 59.94]) + def test_async_job_metadata_preserves_fractional_fps(self, video_client, rate): + resp = video_client.post( + "/v1/videos", + json={ + "prompt": "fractional fps", + "size": "32x32", + "seconds": 1.0, + "fps": rate, + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 202 + body = resp.json() + assert body["fps"] == rate + + def test_async_job_metadata_uses_resolved_default_fps(self, video_client): + """When the request omits ``fps``/``frame_rate``, the queued + ``VideoJob`` reports the pipeline-default rate that the + conversion layer resolved on ``params.frame_rate`` — not + ``None`` — so polling clients see accurate metadata for a + video encoded at the model default.""" + # Force a known default on the mock pipeline so the assertion + # is deterministic. ``MockVisualGen.default_params`` builds a + # fresh ``VisualGenParams``; patching the property here lets + # the test pretend the pipeline default is 12 fps. + from tensorrt_llm.visual_gen import VisualGenParams + + class _FixedDefaultGen(MockVisualGen): + @property + def default_params(self): + return VisualGenParams(frame_rate=12.0) + + gen = _FixedDefaultGen(video_output=_make_dummy_video_tensor()) + # The fixture installs media storage env vars; mirror that. + os.environ["TRTLLM_MEDIA_STORAGE_PATH"] = ( + os.path.dirname(video_client.app.state.__dict__.get("media_storage_path", "/tmp/_vg")) + or "/tmp/_vg" + ) + try: + client = _create_server(gen) + resp = client.post( + "/v1/videos", + json={ + "prompt": "no fps sent", + "size": "32x32", + "seconds": 1.0, + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 202 + body = resp.json() + assert body["fps"] == 12.0 + finally: + os.environ.pop("TRTLLM_MEDIA_STORAGE_PATH", None) + + +def _raise_value_error(_fmt): + raise ValueError("ffmpeg not available; encoder format unsupported") + + +def _raise_runtime_error(_fmt): + raise RuntimeError("MP4 (H.264) format requires ffmpeg to be installed.") + + +@pytest.mark.threadleak(enabled=False) # FileResponse spawns AnyIO worker threads +class TestVideoEncoderFailsFast: + """When an encoder format can't be resolved, the sync and async + video routes must reject the request before any GPU generation + runs. ``resolve_video_format`` raises ``ValueError`` for genuinely + unsupported format strings and ``RuntimeError`` for the + missing-ffmpeg case on ``format='mp4'``; both must surface as a + 400, not a 500.""" + + @pytest.mark.parametrize( + "raiser", + [_raise_value_error, _raise_runtime_error], + ids=["unsupported_format", "missing_ffmpeg"], + ) + def test_sync_route_fails_before_generate(self, video_client, monkeypatch, raiser): + from tensorrt_llm.serve import openai_video_routes as routes + + monkeypatch.setattr(routes, "resolve_video_format", raiser) + # Record whether the generator was called so the assertion + # locks in the fail-fast contract. + video_client.mock_gen.last_inputs = None + resp = video_client.post( + "/v1/videos/generations", + json={ + "prompt": "mp4 without ffmpeg", + "size": "32x32", + "seconds": 1.0, + "fps": 8, + "format": "mp4", + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 400 + assert video_client.mock_gen.last_inputs is None, ( + "generate() must not run when the encoder format is unsupported" + ) + + @pytest.mark.parametrize( + "raiser", + [_raise_value_error, _raise_runtime_error], + ids=["unsupported_format", "missing_ffmpeg"], + ) + def test_async_route_fails_before_queue(self, video_client, monkeypatch, raiser): + from tensorrt_llm.serve import openai_video_routes as routes + + monkeypatch.setattr(routes, "resolve_video_format", raiser) + video_client.mock_gen.last_inputs = None + resp = video_client.post( + "/v1/videos", + json={ + "prompt": "mp4 without ffmpeg", + "size": "32x32", + "seconds": 1.0, + "fps": 8, + "format": "mp4", + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 400 + assert video_client.mock_gen.last_inputs is None + + def test_sync_route_tensor_format_unaffected(self, video_client, monkeypatch): + """Tensor formats have no encoder dependency; a broken + ``resolve_video_format`` must not affect them.""" + from tensorrt_llm.serve import openai_video_routes as routes + + monkeypatch.setattr(routes, "resolve_video_format", _raise_value_error) + resp = video_client.post( + "/v1/videos/generations", + json={ + "prompt": "tensor unaffected", + "size": "32x32", + "seconds": 1.0, + "fps": 8, + "format": "safetensors", + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.threadleak(enabled=False) # FileResponse spawns AnyIO worker threads +class TestAsyncVideoB64JsonTransport: + """``POST /v1/videos`` persists the requested ``response_format`` on + the queued job. ``GET /v1/videos/{id}/content`` honors it: + ``url`` (or unset) returns a ``FileResponse`` download; + ``b64_json`` returns a JSON envelope with the encoded bytes + base64-inlined.""" + + def _drive_job_to_completion(self, client, video_id): + import time as _time + + deadline = _time.time() + 5 + while _time.time() < deadline: + status = client.get(f"/v1/videos/{video_id}").json().get("status") + if status in ("completed", "failed"): + return status + _time.sleep(0.05) + return None + + def test_async_b64_json_returned_at_get_content(self, video_client): + resp = video_client.post( + "/v1/videos", + json={ + "prompt": "async base64", + "size": "32x32", + "seconds": 1.0, + "fps": 8, + "format": "avi", + "response_format": "b64_json", + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 202 + job = resp.json() + assert job["response_format"] == "b64_json" + + status = self._drive_job_to_completion(video_client, job["id"]) + assert status == "completed" + + content = video_client.get(f"/v1/videos/{job['id']}/content") + assert content.status_code == 200 + body = content.json() + assert set(body) >= {"id", "format", "b64_json"} + assert body["id"] == job["id"] + # The encoded payload decodes to non-empty bytes. + raw = base64.b64decode(body["b64_json"]) + assert len(raw) > 0 + + def test_async_url_still_returns_file_response(self, video_client): + """Default and explicit ``response_format='url'`` keep the + existing ``FileResponse`` behavior.""" + resp = video_client.post( + "/v1/videos", + json={ + "prompt": "async url", + "size": "32x32", + "seconds": 1.0, + "fps": 8, + "format": "avi", + "response_format": "url", + }, + headers={"content-type": "application/json"}, + ) + assert resp.status_code == 202 + job = resp.json() + assert job["response_format"] == "url" + + self._drive_job_to_completion(video_client, job["id"]) + content = video_client.get(f"/v1/videos/{job['id']}/content") + assert content.status_code == 200 + # AVI FileResponse carries ``video/x-msvideo``; the b64_json + # branch would have set ``application/json``. + assert content.headers["content-type"] == "video/x-msvideo" diff --git a/tests/unittest/_torch/visual_gen/test_utils.py b/tests/unittest/_torch/visual_gen/test_utils.py index 594b6c530063..bffe0d5ad9a3 100644 --- a/tests/unittest/_torch/visual_gen/test_utils.py +++ b/tests/unittest/_torch/visual_gen/test_utils.py @@ -60,7 +60,7 @@ def test_shard_rope_passthrough_when_inactive(self): s = SequenceSharder(size=1, rank=0, group=None) cos = torch.randn(1, 4, 8) rope = (cos, cos) - assert s.shard_rope(rope, seq_len=4) is rope + assert s.shard_rope(rope, seq_len=4, seq_dim=1) is rope def test_disable_enable_no_collectives(self): s = SequenceSharder(size=4, rank=0, group=None) @@ -102,20 +102,24 @@ def test_shard_rope_bsd_layout(self): B, S, D = 1, 8, 16 cos = torch.arange(B * S * D).view(B, S, D).float() sin = cos + 1000 - out = s.shard_rope((cos, sin), seq_len=S) + out = s.shard_rope((cos, sin), seq_len=S, seq_dim=1) assert out is not None oc, osin = out assert oc.shape == (1, 4, 16) assert torch.equal(oc, cos[:, 4:8].contiguous()) - def test_shard_rope_ambiguous_returns_unchanged(self): - """Two axes equal seq_len → no unique dim → passthrough.""" + def test_shard_rope_explicit_seq_dim_handles_square_layout(self): + """``shard_rope`` requires an explicit ``seq_dim``. Square + ``(B, S, S)`` layouts dispatch on the caller-supplied axis; + the helper does not infer the sequence dimension.""" s = SequenceSharder(size=2, rank=0, group=None) S = 8 cos = torch.zeros(2, S, S) sin = torch.ones(2, S, S) - rope = (cos, sin) - assert s.shard_rope(rope, seq_len=S) is rope + out = s.shard_rope((cos, sin), seq_len=S, seq_dim=1) + assert out is not None + oc, _osin = out + assert oc.shape == (2, 4, S) class TestSequenceSharderFromVgm: @@ -124,10 +128,14 @@ def test_from_vgm_none(self): assert s.size == 1 and s.rank == 0 and not s.is_active def test_from_vgm_head_divisibility(self): + # ``seq_group`` is a callable returning the process group on + # the current ``SequenceSharder`` API; the stub uses a fresh + # ``object()`` to stand in for a real ``ProcessGroup``. + stub_group = object() vgm = SimpleNamespace( seq_size=4, seq_rank=0, - seq_group=None, + seq_group=lambda: stub_group, ulysses_size=2, ) SequenceSharder.from_vgm(vgm, num_attention_heads=8, num_kv_heads=4) diff --git a/tests/unittest/_torch/visual_gen/test_visual_gen_args.py b/tests/unittest/_torch/visual_gen/test_visual_gen_args.py index 540cc0c7f4c8..53604d7da2c4 100644 --- a/tests/unittest/_torch/visual_gen/test_visual_gen_args.py +++ b/tests/unittest/_torch/visual_gen/test_visual_gen_args.py @@ -216,7 +216,7 @@ def test_quant_config_dict_passthrough(self): assert dwq is True def test_quant_config_dict_does_not_leak_dynamic_flags(self): - """AC-5: dynamic flags are not part of the VisualGenArgs schema.""" + """Dynamic quant flags are not part of the VisualGenArgs schema.""" args = VisualGenArgs( model="/tmp/model", quant_config={"quant_algo": "FP8", "dynamic": True}, @@ -413,15 +413,24 @@ def test_list_of_strings_input(self): req = call_args[0] assert req.prompt == ["a sunset", "a city"] - def test_params_default_none(self): - """Omitting params passes None; executor materializes defaults later.""" + def test_params_default_materializes_visual_gen_params(self): + """Omitting params materializes a fresh ``VisualGenParams`` at the + enqueue site, so the executor never sees ``params is None``.""" + from tensorrt_llm.visual_gen import VisualGenParams + vg = self._make_visual_gen_with_mock_executor() vg.generate_async(inputs="a cat") call_args = vg.executor.enqueue_requests.call_args[0][0] req = call_args[0] - assert req.params is None + assert isinstance(req.params, VisualGenParams) + # ``seed`` is materialized at the public Python boundary so every + # downstream layer sees a concrete int; the other universal fields + # stay ``None`` until the executor applies pipeline defaults in + # ``DiffusionExecutor._merge_defaults``. + assert isinstance(req.params.seed, int) + assert req.params.height is None def test_negative_prompt_via_params(self): """negative_prompt is passed through params, not inputs.""" diff --git a/tests/unittest/_torch/visual_gen/test_visual_gen_params.py b/tests/unittest/_torch/visual_gen/test_visual_gen_params.py index 01a7d0c210fb..e69c81d21f1b 100644 --- a/tests/unittest/_torch/visual_gen/test_visual_gen_params.py +++ b/tests/unittest/_torch/visual_gen/test_visual_gen_params.py @@ -53,13 +53,18 @@ def test_default_construction(self): assert params.frame_rate is None assert params.negative_prompt is None assert params.image is None - assert params.mask is None - assert params.image_cond_strength is None + # ``image_cond_strength`` moved to per-pipeline ``extra_params`` + # (only LTX-2 consumes it). It is no longer a top-level field. + assert not hasattr(params, "image_cond_strength") + # `seed` is now ``Optional[int]`` and defaults to None — the engine + # draws a fresh value on the coordinator rank before broadcast. + assert params.seed is None # Concrete defaults - assert params.seed == 42 assert params.num_images_per_prompt == 1 # Extra params assert params.extra_params is None + # The model does not expose a ``mask`` field. + assert not hasattr(params, "mask") def test_explicit_values(self): from tensorrt_llm.visual_gen import VisualGenParams @@ -125,6 +130,17 @@ def test_negative_prompt_on_params(self): params = VisualGenParams(negative_prompt="blurry, low quality") assert params.negative_prompt == "blurry, low quality" + def test_seed_accepts_int64_range(self): + """The Python API does not clamp the seed — only the serve + boundary (openai_protocol request schemas) enforces the + OpenAI DALL-E UINT32 range. ``VisualGenParams.seed`` accepts + any int that ``torch.Generator`` supports.""" + from tensorrt_llm.visual_gen import VisualGenParams + + assert VisualGenParams(seed=0).seed == 0 + # Above the UINT32 boundary — accepted at the Python API. + assert VisualGenParams(seed=2**40).seed == 2**40 + # ============================================================================= # ExtraParamSchema @@ -285,6 +301,7 @@ def test_ltx2_extra_specs(self): expected_keys = { "output_type", "guidance_rescale", + "image_cond_strength", "stg_scale", "stg_blocks", "modality_scale", @@ -403,19 +420,19 @@ def test_all_declared_keys_present_after_merge(self): for key in ltx2_specs: assert key in req.params.extra_params, f"Missing key: {key}" - def test_params_none_materializes_defaults(self): - """req.params=None is the default path from generate_async(params=None); - _merge_defaults should materialize a VisualGenParams from pipeline defaults.""" + def test_default_params_materialize_pipeline_defaults(self): + """A fresh, all-None VisualGenParams (what VisualGen.generate_async + builds when the caller passes ``params=None``) should pick up + every pipeline default after ``_merge_defaults``.""" from tensorrt_llm._torch.visual_gen.executor import DiffusionRequest from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline from tensorrt_llm.visual_gen.params import VisualGenParams executor = self._make_mock_executor(LTX2Pipeline) - req = DiffusionRequest(request_id=0, prompt=["test"], params=None) + req = DiffusionRequest(request_id=0, prompt=["test"], params=VisualGenParams()) self._merge(executor, req) - assert isinstance(req.params, VisualGenParams) # Universal defaults are filled from the pipeline assert req.params.height == 512 assert req.params.width == 768 @@ -464,7 +481,9 @@ def test_ltx2_default_params(self): assert params.height == 512 assert params.width == 768 assert params.num_inference_steps == 40 - assert params.seed == 42 + # Pipelines don't declare a seed default; the executor resolves + # ``None`` to a concrete integer on the coordinator rank. + assert params.seed is None assert params.extra_params is not None assert params.extra_params["stg_scale"] == 0.0 assert params.extra_params["output_type"] == "pt" @@ -668,12 +687,17 @@ def test_client_extracts_metadata_from_ready(self): # ============================================================================= -# Request validation — _validate_request +# Request validation — validate_visual_gen_params # ============================================================================= class TestRequestValidation: - """DiffusionExecutor._validate_request raises ValueError on bad params.""" + """``validate_visual_gen_params`` raises ``ValueError`` on bad params. + + The validator is now called on the coordinator side at + :meth:`VisualGen.generate_async` entry; these tests call it directly + against the pipeline's declared defaults / extra-param specs. + """ def _make_mock_executor(self, pipeline_cls, mock_self=None): executor = MagicMock() @@ -692,15 +716,24 @@ def _make_request(self, **kwargs): return DiffusionRequest(request_id=0, prompt=["test"], params=VisualGenParams(**kwargs)) def _validate(self, executor, req): - from tensorrt_llm._torch.visual_gen.executor import DiffusionExecutor + from tensorrt_llm.visual_gen.params import validate_visual_gen_params - DiffusionExecutor._validate_request(executor, req) + validate_visual_gen_params( + req.params, + declared_defaults=executor.pipeline.default_generation_params, + extra_param_specs=executor.pipeline.extra_param_specs, + ) def _merge_and_validate(self, executor, req): from tensorrt_llm._torch.visual_gen.executor import DiffusionExecutor + from tensorrt_llm.visual_gen.params import validate_visual_gen_params DiffusionExecutor._merge_defaults(executor, req) - DiffusionExecutor._validate_request(executor, req) + validate_visual_gen_params( + req.params, + declared_defaults=executor.pipeline.default_generation_params, + extra_param_specs=executor.pipeline.extra_param_specs, + ) # --- unknown extra_params --- @@ -735,7 +768,7 @@ def test_num_frames_on_image_pipeline_raises(self): executor = self._make_mock_executor(FluxPipeline) req = self._make_request(num_frames=81) - with pytest.raises(ValueError, match="num_frames.*not use it"): + with pytest.raises(ValueError, match="num_frames.*not accept it"): self._validate(executor, req) def test_frame_rate_on_image_pipeline_raises(self): @@ -743,7 +776,27 @@ def test_frame_rate_on_image_pipeline_raises(self): executor = self._make_mock_executor(FluxPipeline) req = self._make_request(frame_rate=24.0) - with pytest.raises(ValueError, match="frame_rate.*not use it"): + with pytest.raises(ValueError, match="frame_rate.*not accept it"): + self._validate(executor, req) + + def test_image_cond_strength_on_ltx2_extra_params_ok(self): + """LTX-2 declares ``image_cond_strength`` in extra_param_specs; + passing it via ``extra_params`` must validate successfully.""" + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + executor = self._make_mock_executor(LTX2Pipeline) + req = self._make_request(extra_params={"image_cond_strength": 0.6}) + self._merge_and_validate(executor, req) # should not raise + + def test_image_cond_strength_on_wan_via_extra_params_raises(self): + """Wan pipelines do not declare ``image_cond_strength`` in + their extra_param_specs, so passing it via ``extra_params`` + must be rejected as an unknown key.""" + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + executor = self._make_mock_executor(WanPipeline, _wan_mock(num_heads=12)) + req = self._make_request(extra_params={"image_cond_strength": 0.8}) + with pytest.raises(ValueError, match="Unknown extra_params"): self._validate(executor, req) def test_image_not_checked_by_validator(self): @@ -781,19 +834,18 @@ def test_none_fields_not_flagged(self): req = self._make_request() # all None self._merge_and_validate(executor, req) - def test_params_none_merge_and_validate_ok(self): - """req.params=None must merge + validate cleanly (VisualGen.generate_async - defaults to params=None, so this is the canonical call path).""" + def test_default_params_merge_and_validate_ok(self): + """A fresh ``VisualGenParams()`` (what the enqueue site builds when + the caller passes ``params=None``) must merge + validate cleanly.""" from tensorrt_llm._torch.visual_gen.executor import DiffusionRequest from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline from tensorrt_llm.visual_gen.params import VisualGenParams executor = self._make_mock_executor(LTX2Pipeline) - req = DiffusionRequest(request_id=0, prompt=["test"], params=None) + req = DiffusionRequest(request_id=0, prompt=["test"], params=VisualGenParams()) self._merge_and_validate(executor, req) # should not raise - assert isinstance(req.params, VisualGenParams) assert req.params.height == 512 assert req.params.extra_params["stg_scale"] == 0.0 @@ -911,37 +963,216 @@ def test_none_extra_param_value_skipped(self): req = self._make_request(extra_params={"boundary_ratio": None}) self._merge_and_validate(executor, req) - # --- process_request returns error response instead of crashing --- - def test_process_request_returns_error_on_validation_failure(self): - """Validation errors become error responses, not server crashes.""" - from tensorrt_llm._torch.visual_gen.executor import DiffusionExecutor, DiffusionResponse +# ============================================================================= +# Parameter validation — message content per category +# ============================================================================= - # Build a mock with real method bindings for the three methods - # that process_request chains through. + +class TestValidateVisualGenParamsMessages: + """``validate_visual_gen_params`` raises ``ValueError`` with a multi-line + message naming every offending field so callers (and HTTP clients) can + fix the request without parsing a structured envelope.""" + + def _make_mock_executor(self, pipeline_cls, mock_self=None): executor = MagicMock() executor.pipeline = MagicMock() - executor.pipeline.__class__.__name__ = "FluxPipeline" - executor.pipeline.default_generation_params = {"height": 1024, "width": 1024} - executor.pipeline.extra_param_specs = {} - executor.pipeline._warmed_up_shapes = set() - executor.pipeline.warmup_cache_key = MagicMock(return_value=(1024, 1024, None)) + executor.pipeline.__class__ = pipeline_cls + executor.pipeline.default_generation_params = pipeline_cls.default_generation_params.fget( + mock_self + ) + executor.pipeline.extra_param_specs = pipeline_cls.extra_param_specs.fget(mock_self) + return executor + + def _make_request(self, **kwargs): + from tensorrt_llm._torch.visual_gen.executor import DiffusionRequest + from tensorrt_llm.visual_gen.params import VisualGenParams + + return DiffusionRequest(request_id=0, prompt=["test"], params=VisualGenParams(**kwargs)) + + def _validate(self, executor, req): + from tensorrt_llm.visual_gen.params import validate_visual_gen_params + + validate_visual_gen_params( + req.params, + declared_defaults=executor.pipeline.default_generation_params, + extra_param_specs=executor.pipeline.extra_param_specs, + ) + + def test_unknown_extra_param_message(self): + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + executor = self._make_mock_executor(LTX2Pipeline) + req = self._make_request(extra_params={"stg_sclae": 1.0, "bogus_key": 2}) + with pytest.raises(ValueError) as excinfo: + self._validate(executor, req) + msg = str(excinfo.value) + assert "Parameter validation failed" in msg + assert "Unknown extra_params" in msg + assert "bogus_key" in msg and "stg_sclae" in msg + + def test_unsupported_universal_field_message(self): + """An image pipeline should reject video-only universal fields and + name every offending field in the message.""" + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux import FluxPipeline + + executor = self._make_mock_executor(FluxPipeline) + req = self._make_request(num_frames=81, frame_rate=24.0) + with pytest.raises(ValueError) as excinfo: + self._validate(executor, req) + msg = str(excinfo.value) + assert "num_frames" in msg + assert "frame_rate" in msg + assert "does not accept it" in msg + + def test_extra_param_type_mismatch_message(self): + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + executor = self._make_mock_executor(LTX2Pipeline) + req = self._make_request(extra_params={"stg_scale": "fast"}) + with pytest.raises(ValueError) as excinfo: + self._validate(executor, req) + msg = str(excinfo.value) + assert "stg_scale" in msg + assert "expected type 'float'" in msg + assert "got str" in msg + + def test_extra_param_out_of_range_message(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + executor = self._make_mock_executor(WanPipeline, _wan_mock(is_wan22_14b=True, num_heads=12)) + req = self._make_request(extra_params={"boundary_ratio": -0.5}) + with pytest.raises(ValueError) as excinfo: + self._validate(executor, req) + msg = str(excinfo.value) + assert "boundary_ratio" in msg + assert "-0.5" in msg + assert "[0.0, 1.0]" in msg + + +# ============================================================================= +# Seed resolution — coordinator-rank materialization +# ============================================================================= + + +class TestResolveSeed: + """``VisualGen.generate_async`` materializes ``params.seed`` once on the + coordinator process, so the request that travels over ZMQ already + carries a concrete int and rank-0's broadcast propagates the same + value to every rank.""" + + def _make_visual_gen(self): + """Build a minimal ``VisualGen`` shim that exposes ``generate_async`` + without spinning up the worker process.""" + import itertools + + from tensorrt_llm.visual_gen.visual_gen import VisualGen + + executor = MagicMock() + executor.default_generation_params = {} + executor.extra_param_specs = {} + executor.enqueue_requests = MagicMock() + + vg = VisualGen.__new__(VisualGen) + vg.executor = executor + vg._req_counter = itertools.count() + return vg + + def _enqueued_request(self, vg): + vg.executor.enqueue_requests.assert_called_once() + return vg.executor.enqueue_requests.call_args[0][0][0] + + def test_seed_none_is_materialized(self): + from tensorrt_llm.visual_gen import VisualGenParams + + vg = self._make_visual_gen() + vg.generate_async("x", params=VisualGenParams()) + req = self._enqueued_request(vg) + assert isinstance(req.params.seed, int) + assert 0 <= req.params.seed < (1 << 63) + + def test_concrete_seed_preserved(self): + from tensorrt_llm.visual_gen import VisualGenParams + + vg = self._make_visual_gen() + vg.generate_async("x", params=VisualGenParams(seed=12345)) + req = self._enqueued_request(vg) + assert req.params.seed == 12345 + + def test_two_calls_draw_two_distinct_seeds(self): + """Each request gets its own random seed when None is sent.""" + from tensorrt_llm.visual_gen import VisualGenParams + + vg = self._make_visual_gen() + vg.generate_async("x", params=VisualGenParams()) + vg.generate_async("y", params=VisualGenParams()) + calls = vg.executor.enqueue_requests.call_args_list + seed_a = calls[0][0][0][0].params.seed + seed_b = calls[1][0][0][0].params.seed + # Probabilistic — collision space is 2**63; essentially impossible. + assert seed_a != seed_b + + def test_caller_params_not_mutated(self): + """Resolution operates on the deep-copied snapshot, not the caller's + original ``VisualGenParams`` instance.""" + from tensorrt_llm.visual_gen import VisualGenParams + + vg = self._make_visual_gen() + caller_params = VisualGenParams() + vg.generate_async("x", params=caller_params) + assert caller_params.seed is None + assert isinstance(self._enqueued_request(vg).params.seed, int) + + +# ============================================================================= +# DiffusionResponse — engine-failure transport +# ============================================================================= + + +class TestEngineFailureTransport: + """Validation is enforced at :meth:`VisualGen.generate_async` entry, so + by the time a request reaches ``process_request`` only runtime + failures from ``pipeline.infer()`` can produce an error response. + The error message rides back on ``DiffusionResponse.error_msg``. + """ + + def _make_executor(self, pipeline_cls, mock_self=None): + executor = MagicMock() executor.rank = 0 executor.device_id = 0 executor.response_queue = MagicMock() + executor.pipeline = MagicMock() + executor.pipeline.__class__ = pipeline_cls + executor.pipeline.default_generation_params = pipeline_cls.default_generation_params.fget( + mock_self + ) + executor.pipeline.extra_param_specs = pipeline_cls.extra_param_specs.fget(mock_self) + return executor - # Wire real methods onto the mock so process_request uses them + def test_runtime_error_carried_on_response(self): + from tensorrt_llm._torch.visual_gen.executor import ( + DiffusionExecutor, + DiffusionRequest, + DiffusionResponse, + ) + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux import FluxPipeline + from tensorrt_llm.visual_gen.params import VisualGenParams + + executor = self._make_executor(FluxPipeline) executor._merge_defaults = lambda req: DiffusionExecutor._merge_defaults(executor, req) - executor._validate_request = lambda req: DiffusionExecutor._validate_request(executor, req) + executor.pipeline.warmup_cache_key = MagicMock(return_value=(1024, 1024, None)) + executor.pipeline._warmed_up_shapes = None + executor.pipeline.infer = MagicMock(side_effect=RuntimeError("oops")) - req = self._make_request(num_frames=81, extra_params={"bad": 1}) + req = DiffusionRequest( + request_id=7, + prompt=["test"], + params=VisualGenParams(), + ) - # Call the real process_request DiffusionExecutor.process_request(executor, req) - # Should have put an error response, not crashed executor.response_queue.put.assert_called_once() resp = executor.response_queue.put.call_args[0][0] assert isinstance(resp, DiffusionResponse) - assert resp.error_msg is not None - assert "validation failed" in resp.error_msg.lower() + assert resp.error_msg == "oops" diff --git a/tests/unittest/_torch/visual_gen/test_visual_gen_utils.py b/tests/unittest/_torch/visual_gen/test_visual_gen_utils.py new file mode 100644 index 000000000000..4b4836392a1f --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_visual_gen_utils.py @@ -0,0 +1,343 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Direct tests for :mod:`tensorrt_llm.serve.visual_gen_utils`. + +These tests bypass the HTTP transport and call ``parse_visual_gen_params`` +and the ``_warn_if_set_with_no_semantic`` / ``_merge_extra_params`` +helpers directly against constructed Pydantic request objects and a +stub :class:`VisualGen`. They cover the ``extra_params`` merge truth +table plus the field-by-field overlay contract. +""" + +from __future__ import annotations + +import base64 +from io import BytesIO +from typing import Any, Dict, Optional + +import pytest +from PIL import Image + +from tensorrt_llm.serve.openai_protocol import ImageGenerationRequest, VideoGenerationRequest +from tensorrt_llm.serve.visual_gen_utils import ( + _merge_extra_params, + _warn_if_set_with_no_semantic, + parse_visual_gen_params, +) +from tensorrt_llm.visual_gen import VisualGenParams + + +class _StubExtraParamSpec: + def __init__(self, default: Any = None) -> None: + self.default = default + + +class _StubVisualGen: + """Minimal :class:`VisualGen` stand-in for direct conversion tests. + + The conversion layer only reads ``default_params``, ``model``, and + ``extra_param_specs`` — populate those directly. + """ + + def __init__( + self, + defaults: Optional[Dict[str, Any]] = None, + extra_param_specs: Optional[Dict[str, Any]] = None, + model: str = "stub", + ) -> None: + self._defaults = defaults or {} + self.extra_param_specs = extra_param_specs or {} + self.model = model + + @property + def default_params(self) -> VisualGenParams: + # Always return a fresh instance so the conversion layer can + # mutate it without leaking across tests. + return VisualGenParams(**self._defaults) + + +@pytest.fixture +def image_request_defaults(): + return ImageGenerationRequest(prompt="cat", response_format="b64_json") + + +@pytest.fixture +def video_request_defaults(): + return VideoGenerationRequest(prompt="storm", response_format="b64_json") + + +# ============================================================================= +# Default overlay — only client-sent fields override pipeline defaults +# ============================================================================= + + +class TestDefaultOverlay: + def test_all_none_request_keeps_pipeline_defaults(self, image_request_defaults): + generator = _StubVisualGen( + defaults={"width": 1024, "height": 1024, "num_inference_steps": 30}, + ) + params = parse_visual_gen_params(image_request_defaults, "id-1", generator) + assert params.width == 1024 + assert params.height == 1024 + assert params.num_inference_steps == 30 + + def test_image_explicit_fields_override_defaults(self): + generator = _StubVisualGen( + defaults={"width": 1024, "height": 1024, "num_inference_steps": 30}, + ) + request = ImageGenerationRequest( + prompt="cat", + width=512, + height=512, + num_inference_steps=10, + guidance_scale=4.0, + max_sequence_length=128, + seed=99, + n=4, + negative_prompt="blurry", + ) + params = parse_visual_gen_params(request, "id-2", generator) + assert (params.width, params.height) == (512, 512) + assert params.num_inference_steps == 10 + assert params.guidance_scale == 4.0 + assert params.max_sequence_length == 128 + assert params.seed == 99 + assert params.num_images_per_prompt == 4 + assert params.negative_prompt == "blurry" + + def test_size_string_used_when_width_height_absent(self): + generator = _StubVisualGen() + request = ImageGenerationRequest(prompt="cat", size="768x256") + params = parse_visual_gen_params(request, "id-3", generator) + assert (params.width, params.height) == (768, 256) + + def test_width_height_pair_wins_over_size(self): + generator = _StubVisualGen() + request = ImageGenerationRequest(prompt="cat", size="768x256", width=128, height=64) + params = parse_visual_gen_params(request, "id-4", generator) + assert (params.width, params.height) == (128, 64) + + def test_image_seed_propagates(self): + generator = _StubVisualGen() + request = ImageGenerationRequest(prompt="cat", seed=12345) + params = parse_visual_gen_params(request, "id-seed", generator) + assert params.seed == 12345 + + +# ============================================================================= +# Seed range clamp on the serve boundary +# ============================================================================= + + +class TestSeedLowerBoundOnServeBoundary: + """Negative seeds are rejected at the HTTP request schema; the rest + of the int64 range is accepted, matching what the underlying + ``torch.Generator.manual_seed`` supports. + """ + + def test_image_seed_accepts_zero_and_large_values(self): + from tensorrt_llm.serve.openai_protocol import ImageGenerationRequest + + assert ImageGenerationRequest(prompt="x", seed=0).seed == 0 + large = 2**40 + assert ImageGenerationRequest(prompt="x", seed=large).seed == large + + def test_image_seed_rejects_negative(self): + from pydantic import ValidationError + + from tensorrt_llm.serve.openai_protocol import ImageGenerationRequest + + with pytest.raises(ValidationError): + ImageGenerationRequest(prompt="x", seed=-1) + + def test_video_seed_rejects_negative(self): + from pydantic import ValidationError + + from tensorrt_llm.serve.openai_protocol import VideoGenerationRequest + + with pytest.raises(ValidationError): + VideoGenerationRequest(prompt="x", seed=-1) + + +# ============================================================================= +# OpenAI-shape "warn-on-set" fields +# ============================================================================= + + +class TestWarnOnSet: + """The TRT-LLM logger doesn't propagate through Python's root logger, + so these tests monkeypatch :func:`logger.warning` directly and + inspect what the helper would have emitted.""" + + def _capture_warnings(self, monkeypatch): + captured: list[str] = [] + + def _fake_warning(msg: str, *args: object, **kwargs: object) -> None: + try: + rendered = msg % args if args else msg + except (TypeError, ValueError): + rendered = str(msg) + captured.append(rendered) + + from tensorrt_llm.serve import visual_gen_utils as vgu + + monkeypatch.setattr(vgu.logger, "warning", _fake_warning) + return captured + + def test_quality_hd_does_not_override_steps(self): + generator = _StubVisualGen(defaults={"num_inference_steps": 25}) + request = ImageGenerationRequest(prompt="cat", quality="hd") + params = parse_visual_gen_params(request, "id-q", generator) + # ``quality`` is an OpenAI-shape no-semantic field. The pipeline + # default for ``num_inference_steps`` must reach the engine + # unchanged. + assert params.num_inference_steps == 25 + + def test_style_set_logs_warning(self, monkeypatch): + captured = self._capture_warnings(monkeypatch) + request = ImageGenerationRequest(prompt="cat", style="vivid") + _warn_if_set_with_no_semantic(request, "stub") + assert any("'style'" in m for m in captured) + + def test_user_set_does_not_log_warning(self, monkeypatch): + captured = self._capture_warnings(monkeypatch) + request = ImageGenerationRequest(prompt="cat", user="abc") + _warn_if_set_with_no_semantic(request, "stub") + assert not any("'user'" in m for m in captured) + + def test_model_mismatch_logs_warning(self, monkeypatch): + captured = self._capture_warnings(monkeypatch) + request = ImageGenerationRequest(prompt="cat", model="some-other") + _warn_if_set_with_no_semantic(request, "flux2") + assert any("'model'" in m for m in captured) + + def test_model_match_does_not_log_warning(self, monkeypatch): + captured = self._capture_warnings(monkeypatch) + request = ImageGenerationRequest(prompt="cat", model="flux2") + _warn_if_set_with_no_semantic(request, "flux2") + assert not any("'model'" in m for m in captured) + + +# ============================================================================= +# Video frame-budget derivation +# ============================================================================= + + +class TestVideoFrameBudget: + def test_num_frames_wins_over_seconds_times_frame_rate(self): + generator = _StubVisualGen(defaults={"frame_rate": 24.0}) + request = VideoGenerationRequest(prompt="x", num_frames=33, seconds=10.0) + params = parse_visual_gen_params(request, "id-v1", generator) + assert params.num_frames == 33 + + def test_seconds_and_frame_rate_derive_num_frames(self): + generator = _StubVisualGen(defaults={"frame_rate": 12.0}) + # fps alias resolves to frame_rate via populate_by_name=True + request = VideoGenerationRequest(prompt="x", seconds=2.5, fps=24) + params = parse_visual_gen_params(request, "id-v2", generator) + assert params.frame_rate == 24.0 + assert params.num_frames == int(2.5 * 24.0) + + def test_seconds_alone_uses_pipeline_frame_rate(self): + generator = _StubVisualGen(defaults={"frame_rate": 16.0}) + request = VideoGenerationRequest(prompt="x", seconds=4.0) + params = parse_visual_gen_params(request, "id-v3", generator) + assert params.frame_rate == 16.0 + assert params.num_frames == int(4.0 * 16.0) + + def test_video_does_not_carry_n(self): + generator = _StubVisualGen() + # Video request has no ``n`` field — Pydantic rejects it at + # schema time, but constructing the request without it must + # leave ``num_images_per_prompt`` unchanged from the pipeline + # default. + request = VideoGenerationRequest(prompt="x") + params = parse_visual_gen_params(request, "id-v4", generator) + assert params.num_images_per_prompt == 1 + + +# ============================================================================= +# input_reference materialization +# ============================================================================= + + +class TestInputReferenceMaterialization: + def test_base64_reference_written_to_disk(self, tmp_path): + generator = _StubVisualGen() + img = Image.new("RGB", (4, 4), (10, 20, 30)) + buf = BytesIO() + img.save(buf, format="PNG") + b64 = base64.b64encode(buf.getvalue()).decode() + request = VideoGenerationRequest(prompt="x", input_reference=b64) + params = parse_visual_gen_params( + request, "vid-1", generator, media_storage_path=str(tmp_path) + ) + assert params.image is not None + assert str(params.image).endswith("vid-1_reference.png") + # The decoded image is identical to what we passed in. + with open(params.image, "rb") as f: + decoded = Image.open(f).convert("RGB") + assert decoded.size == (4, 4) + + def test_missing_media_storage_path_raises(self): + generator = _StubVisualGen() + img = Image.new("RGB", (2, 2)) + buf = BytesIO() + img.save(buf, format="PNG") + b64 = base64.b64encode(buf.getvalue()).decode() + request = VideoGenerationRequest(prompt="x", input_reference=b64) + with pytest.raises(ValueError, match="media_storage_path"): + parse_visual_gen_params(request, "vid-2", generator, media_storage_path=None) + + +# ============================================================================= +# _merge_extra_params — the merge truth table +# ============================================================================= + + +class TestMergeExtraParams: + def _make_params(self, defaults: Optional[Dict[str, Any]] = None) -> VisualGenParams: + return VisualGenParams(extra_params=dict(defaults) if defaults else None) + + def test_omitted_key_keeps_default(self): + specs = {"stg_scale": _StubExtraParamSpec(default=1.0)} + params = self._make_params({"stg_scale": 1.0}) + _merge_extra_params(params, request_extras=None, extra_param_specs=specs) + assert params.extra_params == {"stg_scale": 1.0} + + def test_known_non_null_overrides_default(self): + specs = {"stg_scale": _StubExtraParamSpec(default=1.0)} + params = self._make_params({"stg_scale": 1.0}) + _merge_extra_params(params, {"stg_scale": 2.5}, specs) + assert params.extra_params["stg_scale"] == 2.5 + + def test_known_null_keeps_default(self): + """Schema-aware null sentinel: ``{"stg_scale": null}`` does not + clear the pre-seeded pipeline default and does not pass through + to the executor as ``None`` either.""" + specs = {"stg_scale": _StubExtraParamSpec(default=1.0)} + params = self._make_params({"stg_scale": 1.0}) + _merge_extra_params(params, {"stg_scale": None}, specs) + assert params.extra_params["stg_scale"] == 1.0 + + def test_unknown_key_passes_through_with_value(self): + """Unknown keys are preserved verbatim so the executor's + strict-key validator raises ``unknown_extra_param``.""" + specs = {"stg_scale": _StubExtraParamSpec(default=1.0)} + params = self._make_params({"stg_scale": 1.0}) + _merge_extra_params(params, {"stg_sclae": 9.9}, specs) + assert params.extra_params == {"stg_scale": 1.0, "stg_sclae": 9.9} + + def test_unknown_key_with_null_passes_through(self): + """Critical: unknown + null is *not* stripped. A schema-blind + "drop every null" rule would let typos like ``{"stg_sclae": + null}`` reach the engine as a silent no-op.""" + specs = {"stg_scale": _StubExtraParamSpec(default=1.0)} + params = self._make_params({"stg_scale": 1.0}) + _merge_extra_params(params, {"stg_sclae": None}, specs) + assert params.extra_params["stg_sclae"] is None + + def test_empty_extras_dict_normalizes_to_none(self): + params = self._make_params() + _merge_extra_params(params, request_extras=None, extra_param_specs={}) + assert params.extra_params is None diff --git a/tests/unittest/media/test_encoding.py b/tests/unittest/media/test_encoding.py index 3e380955b84f..ab2c65603ff4 100644 --- a/tests/unittest/media/test_encoding.py +++ b/tests/unittest/media/test_encoding.py @@ -78,10 +78,23 @@ def test_image_to_bytes_returns_nonempty_png(): assert img.format == "PNG" -def test_save_image_strips_batch_dim(tmp_path): - """save_image accepts (B, H, W, C) and writes the first slice.""" +def test_save_image_rejects_batched_tensor_size_gt_1(tmp_path): + """save_image raises ValueError for (B>1, H, W, C) tensors. + + The single-path API requires the caller to disambiguate when the + tensor carries a real batch axis; see :func:`save_images` for the + multi-path fan-out. + """ batched = torch.stack([_dummy_image(), _dummy_image(), _dummy_image()]) - target = tmp_path / "first.png" + target = tmp_path / "out.png" + with pytest.raises(ValueError, match="batched tensor of size 3"): + save_image(batched, target) + + +def test_save_image_accepts_batch_size_1(tmp_path): + """save_image accepts (1, H, W, C) — the leading axis is unwrapped.""" + batched = _dummy_image().unsqueeze(0) + target = tmp_path / "single.png" saved = save_image(batched, target) assert Path(saved).exists() img = Image.open(saved) From dab34008572fb05e7b52f3c9947d1e9a80693dc9 Mon Sep 17 00:00:00 2001 From: Dhinesh Ponnarasan <160256912+DhineshPonnarasan@users.noreply.github.com> Date: Wed, 10 Jun 2026 11:47:11 -0400 Subject: [PATCH 4/6] [None][test] Add MLA chunked-prefill SM dispatch regression coverage (#13904) Signed-off-by: Dhinesh Ponnarasan --- .../_torch/attention/test_attention_mla.py | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/tests/unittest/_torch/attention/test_attention_mla.py b/tests/unittest/_torch/attention/test_attention_mla.py index bd01ed363dc4..7d0d118b6542 100644 --- a/tests/unittest/_torch/attention/test_attention_mla.py +++ b/tests/unittest/_torch/attention/test_attention_mla.py @@ -379,6 +379,78 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: } +@pytest.mark.parametrize( + "sm_version,expected_path", + [ + (90, "cached_kv"), + (99, "cached_kv"), + (100, "chunked_prefill"), + ], +) +def test_mla_chunked_prefill_dispatch_by_sm(sm_version, expected_path, + monkeypatch): + import tensorrt_llm._torch.modules.attention as attention_module + + class FakeTrtllmAttention: + + @staticmethod + def has_cached_kv_for_mla_context_warmup(_metadata): + return False + + @staticmethod + def is_chunked_prefill_for_mla_context(_metadata): + return True + + @staticmethod + def has_cached_kv_for_mla_context(_metadata): + return False + + class FakeMetadata: + pass + + class FakeAttention: + + def __init__(self): + self.mha = FakeTrtllmAttention() + + @staticmethod + def forward_context_with_chunked_prefill(*_args, **_kwargs): + return "chunked_prefill" + + @staticmethod + def forward_context_with_cached_kv(*_args, **_kwargs): + return "cached_kv" + + @staticmethod + def forward_context_default(*_args, **_kwargs): + return "default" + + monkeypatch.setattr(attention_module, "TrtllmAttention", + FakeTrtllmAttention) + monkeypatch.setattr(attention_module, "TrtllmAttentionMetadata", + FakeMetadata) + monkeypatch.setattr(attention_module, "get_sm_version", lambda: sm_version) + + q = torch.empty((1, 8), dtype=torch.float16) + compressed_kv = torch.empty((1, 4), dtype=torch.float16) + k_pe = torch.empty((1, 4), dtype=torch.float16) + position_ids = torch.zeros((1, ), dtype=torch.int64) + output = torch.empty((1, 8), dtype=torch.float16) + latent_cache = torch.empty((1, 1, 8), dtype=torch.float16) + + result = attention_module.MLA.forward_context( + FakeAttention(), + q, + compressed_kv, + k_pe, + position_ids, + FakeMetadata(), + output, + latent_cache, + ) + assert result == expected_path + + # Convert parameterized tests to pytest parametrize @pytest.mark.parametrize("scenario", scenarios, ids=lambda x: f"scenario: {x}") @pytest.mark.parametrize("context_sequence_lengths", From 0be1447c71c7bdbceac6060984dfd2cfd7cc36e3 Mon Sep 17 00:00:00 2001 From: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Date: Wed, 10 Jun 2026 09:06:47 -0700 Subject: [PATCH 5/6] [TRTLLM-12648][test] enable disagg cancellation stress test (#15174) Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> --- .../defs/stress_test/disagg_cancel/README.md | 173 ++++-- .../disagg_cancel/configs/README.md | 22 +- .../configs/marathon_cpp_v1_deepseek.yaml | 25 +- .../configs/marathon_python_v2_qwen.yaml | 5 + .../defs/stress_test/disagg_cancel/harness.py | 493 ++++++++++++++++-- .../test_disagg_cancel_stress.py | 143 +++-- .../test_lists/qa/llm_function_stress.txt | 1 + 7 files changed, 757 insertions(+), 105 deletions(-) diff --git a/tests/integration/defs/stress_test/disagg_cancel/README.md b/tests/integration/defs/stress_test/disagg_cancel/README.md index 8cc9650a8bf3..26ef1b83d780 100644 --- a/tests/integration/defs/stress_test/disagg_cancel/README.md +++ b/tests/integration/defs/stress_test/disagg_cancel/README.md @@ -1,6 +1,6 @@ # Disaggregated Cancellation Stress-Test Suite -Marathon-style stress tests that gate regressions of the bug class +Disaggregated stress tests that gate regressions of the bug class fixed by (cleanup / lifetime / quiescence invariants in the disagg KV transceiver under heavy mid-flight cancellation). @@ -13,8 +13,22 @@ transceiver under heavy mid-flight cancellation). ## Status -The harness class structure and lifecycle are in place. Thread bodies -land incrementally: +The registered QA stress entry now launches a real C++/V1 DeepSeek +disaggregated cluster in `log_only` mode. That mode sends normal +non-cancel completion probes through the front-end and scans saved +worker/server logs for UAF, broken-promise, and segmentation-fault +signatures. It is intentionally narrow so it can run regularly before +in-flight cancellation and poison-buffer hardening are available. + +The full cancellation/poison marathon is implemented as an explicit +mode switch, but it is not the registered default yet. + +| Mode | CI status | Threads | Coverage | +|------|-----------|---------|----------| +| `log_only` | Registered in `qa/llm_function_stress.txt` | log-only probe + log scanner | startup/data-path crash guard: UAF, broken promise, segfault-class logs | +| `full_cancel_poison` | Opt-in only | load, canary, injector, log scanner, metrics | cancellation load, failure injection, poison canaries, KV-growth guard | + +Thread bodies: | Thread | Status | |--------|--------| @@ -26,8 +40,8 @@ land incrementally: Component-level coverage: `test_log_scanner.py`, `test_metrics_thread.py`, `test_injector.py`, `test_canary.py`, `test_load_thread.py`. The -parametrized marathon pytest still runs a lifecycle smoke until -`setup()` launches a real cluster. +parametrized C++/V1 DeepSeek run is registered in the QA stress test +list as a real `log_only` guardrail. ## File layout @@ -52,17 +66,76 @@ Future additions: - `tools/generate_canary_references.py` — one-shot reference generator that records greedy-decode token IDs for the canary prompts. - `configs/stress_canary_prompts.json` — canary prompts + recorded - reference token IDs (consumed by the canary thread). + reference token IDs for `full_cancel_poison`. - Per-scenario YAMLs covering additional axes: 1P1D, 4P2D, V1+Python, UCX, block-reuse-off, overlap-off, aggressive-timeout, multi-node (all Python-only test-side configuration). +## Mode Switch + +The active mode is controlled by +`configs/marathon_cpp_v1_deepseek.yaml`: + +```yaml +stress_config: + mode: log_only + duration_min: 10 +``` + +Use `log_only` for regular CI until both runtime features are in +place: + +- in-flight request cancellation support for the disaggregated path. +- poison-buffer hardening that makes poisoned cache transfers + expected and recoverable. + +To switch to the full cancellation/poison marathon after those +features are ready: + +1. Set `stress_config.mode: full_cancel_poison`. +2. Set `stress_config.duration_min: 120` for the two-hour marathon. +3. Keep or tune `base_concurrency`, `bursts`, and `injections`. +4. Add `configs/stress_canary_prompts.json` with token references and + keep `canary.check_token_equivalent: true`. +5. Add the poison-buffer hard-zero/expected-recovery patterns that + match the finalized runtime behavior. +6. Raise the test-list timeout back to a full-marathon budget, e.g. + `TIMEOUT (150)`. + ## How to run -The marathons are **not** registered in pre-merge CI. They are run -nightly / weekly via -`tests/integration/test_lists/qa/llm_function_stress.txt` (wiring -lands with the explicit CI-registration change). +### Scheduled QA stress run + +The C++/V1 DeepSeek marathon is registered in +`tests/integration/test_lists/qa/llm_function_stress.txt`, which makes +it eligible for the QA/Jenkins job that consumes that stress list. This +PR does not create or modify the scheduler for that job; the exact +cadence and wall-clock start time are owned by QA CI configuration +outside this directory. The in-repo QA README describes QA lists as +regular daily/release and weekly/release/on-demand coverage, but does +not define a file-specific cadence for `llm_function_stress.txt`. + +The registered entry is: + +```text +stress_test/disagg_cancel/test_disagg_cancel_stress.py::test_disagg_cancellation_marathon[marathon_cpp_v1_deepseek.yaml] TIMEOUT (45) +``` + +The integration test-list parser interprets `TIMEOUT (45)` in +minutes. CI should run the list from `tests/integration/defs` with: + +```bash +pytest --test-list=../test_lists/qa/llm_function_stress.txt \ + --output-dir= \ + -s -v +``` + +The scheduled runner must use the normal TRT-LLM integration container +or virtual environment with GPU access, `trtllm-serve` on `PATH`, and +`LLM_MODELS_ROOT` set so `DeepSeek-V3-Lite/bf16` resolves to local +model weights. The current registered run is `log_only`: setup can +take up to 20 minutes, then the harness probes for 10 minutes and +tails worker/server logs. ### Unit tests (no GPU, no cluster) @@ -105,12 +178,13 @@ python3 -m pytest -c /dev/null -o addopts= \ tests/integration/defs/stress_test/disagg_cancel/test_disagg_cancel_stress.py::test_all_marathon_yamls_parse_and_validate -v ``` -All component tests together: +All component tests together, excluding the real cluster entry: ```bash python3 -m pytest -c /dev/null -o addopts= \ --confcutdir=tests/integration/defs/stress_test \ - tests/integration/defs/stress_test/disagg_cancel/ -q + tests/integration/defs/stress_test/disagg_cancel/ \ + -k "not test_disagg_cancellation_marathon" -q ``` In a full TRT-LLM dev container/venv (with `transformers` installed), @@ -120,27 +194,63 @@ the same tests also run under the normal integration pytest path: pytest -sv tests/integration/defs/stress_test/disagg_cancel/test_injector.py ``` -### Lifecycle smoke (injector not exercised on real workers) +### Manual regular guardrail run + +From a full TRT-LLM integration environment: + +```bash +cd /path/to/TensorRT-LLM/tests/integration/defs +export LLM_MODELS_ROOT=/path/to/model/root + +pytest stress_test/disagg_cancel/test_disagg_cancel_stress.py \ + --test-list=../test_lists/qa/llm_function_stress.txt \ + --output-dir=/tmp/trtllm-disagg-cancel-stress \ + -s -v +``` + +To collect without running: ```bash -pytest -sv tests/integration/defs/stress_test/disagg_cancel/test_disagg_cancel_stress.py::test_disagg_cancellation_marathon +pytest stress_test/disagg_cancel/test_disagg_cancel_stress.py \ + --test-list=../test_lists/qa/llm_function_stress.txt \ + --output-dir=/tmp/trtllm-disagg-cancel-stress \ + -s --co -q +``` + +### Manual CI trigger + +On a GitHub pull request, ask the CI bot which stress stages are +available, then trigger the QA stress stage that consumes +`tests/integration/test_lists/qa/llm_function_stress.txt`: + +```text +/bot help +/bot run --extra-stage "" ``` -`setup()` is still a stub, so this only checks harness lifecycle -(`setup` → `start` → `wait` → `stop`). The injector thread exits -immediately because no workers are registered via -`bind_tracked_workers()`. +The bot stage name is owned by CI/Jenkins configuration and is not +declared in this directory. -### Local marathon (after `setup()` lands) +### Manual full cancellation/poison run -Once `setup()` launches a real 3P3D cluster and registers workers, -the full 2-hour marathon runs via the same pytest entry point. For -development, set `duration_min: 10` and trim `injections:` in the -YAML. +After the runtime support is in place, switch the YAML to +`mode: full_cancel_poison`, set the intended duration and canary +references, then run the same pytest entry point. For development, +use a shorter `duration_min` and trim `injections:` locally before +restoring the checked-in values. ## Pass criteria -A marathon run is "clean" iff all of the following hold: +`log_only` is clean iff all of the following hold: + +- The 3P3D disaggregated cluster starts and reaches readiness. +- At least one normal completion probe succeeds through the + disaggregated front-end. +- No hard-zero log patterns for UAF, broken promise, or + segmentation-fault-class failures appear in any saved worker or + disagg-server log. + +`full_cancel_poison` is clean iff all of the following hold: - No hard-zero log patterns (e.g. `Cannot cancel request`, `Broken promise`, `unquiesced`, double-free / UAF traces) appear in any @@ -157,23 +267,22 @@ A marathon run is "clean" iff all of the following hold: - KV-cache utilization growth ≤ 10 percentage points end-to-end (leak guard). -Concrete thresholds for each metric are declared in the marathon -YAML's `pass_criteria:` block. +Concrete thresholds for each metric are declared in the marathon YAML. ## How to debug a failure -(Stub — the full debug guide lands together with the thread -implementations.) - -For now, when the skeleton test fails: +When the regular guardrail fails: 1. Confirm the YAML parses: ```bash python -c "from harness import StressConfig; StressConfig.from_yaml_path('configs/marathon_cpp_v1_deepseek.yaml')" ``` 2. Check the `failure_reason` field in `collect_results()` output. -3. Look at the pytest stdout for harness `logger` lines (each thread - logs its identity on entry / exit). +3. Inspect the log tails printed by `disagg_test_utils.terminate()` + during teardown; saved worker logs and `disagg_server.log` are + tailed before cleanup. +4. If setup times out, confirm `LLM_MODELS_ROOT`, GPU count, and + `trtllm-serve` availability in the integration environment. ## Cross-references diff --git a/tests/integration/defs/stress_test/disagg_cancel/configs/README.md b/tests/integration/defs/stress_test/disagg_cancel/configs/README.md index a95545b7284a..b29f02390005 100644 --- a/tests/integration/defs/stress_test/disagg_cancel/configs/README.md +++ b/tests/integration/defs/stress_test/disagg_cancel/configs/README.md @@ -25,6 +25,25 @@ The new `stress_config:` top-level block is consumed by `StressConfig` itself (dataclass field docstrings) and the example values in `marathon_cpp_v1_deepseek.yaml`. +## Harness modes + +`stress_config.mode` is the switch between the regular guardrail and +the full cancellation/poison marathon: + +- `log_only`: registered CI mode. The harness launches the real + disaggregated cluster, sends normal non-cancel probes, and scans + worker/server logs for UAF, broken-promise, and segmentation-fault + signatures. This mode is safe before in-flight cancellation and + poison-buffer hardening are available. +- `full_cancel_poison`: opt-in mode for the completed runtime. The + harness enables the cancellation load, SIGSTOP/SIGKILL injections, + token-equivalent canaries, metrics scraping, and KV-growth checks. + +When switching from `log_only` to `full_cancel_poison`, update +`duration_min`, canary references, poison-buffer log expectations, and +the test-list timeout together. The top-level README has the exact +checklist. + ## Backend-knob axis: KV-cache manager × transceiver runtime Two knobs select which (KV cache manager × transceiver runtime) @@ -51,7 +70,8 @@ Python changes** required beyond extending the parametrize list. To add a new config: 1. Copy `marathon_cpp_v1_deepseek.yaml` as a template. -2. Adjust `model`, `kv_cache_manager`, `transceiver`, and any +2. Choose `stress_config.mode`, then adjust `model`, + `kv_cache_manager`, `transceiver`, and any load-shape knobs (`base_concurrency`, `client_cancel_rate`, `output_length`, `injections:`, `pass_criteria:`). 3. Add the new filename to `_MARATHON_CONFIGS` in diff --git a/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_cpp_v1_deepseek.yaml b/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_cpp_v1_deepseek.yaml index 7764035ce05f..e1fb72872d09 100644 --- a/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_cpp_v1_deepseek.yaml +++ b/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_cpp_v1_deepseek.yaml @@ -52,13 +52,29 @@ generation_servers: # Schema documented in ../README.md. # ============================================================================ stress_config: - duration_min: 120 # 2 h marathon + # MODE SWITCH (intentionally loud): + # - log_only: current registered CI mode. Launches the real disagg cluster, + # sends normal non-cancel probes, and scans saved worker/server logs for + # UAF, broken-promise, and segmentation-fault signatures. This mode does + # NOT require in-flight cancellation or poison-buffer support. + # - full_cancel_poison: future opt-in mode after in-flight cancellation and + # poison-buffer hardening are available. It enables the cancellation load, + # fault injections, token-equivalent canaries, and KV-growth checks below. + mode: log_only + duration_min: 10 # regular guardrail run; use 120 for full_cancel_poison # Backend-knob axis selectors. Must match # context_servers / generation_servers above. kv_cache_manager: v1 # V1 (C++) KV cache manager transceiver: cpp # C++-backed transceiver (BindKvCacheTransceiver) + log_only_probe: + interval_s: 30 + prompt: "Write one sentence about reliable distributed inference." + max_tokens: 32 + seed: 42 + request_timeout_s: 30 + base_concurrency: 64 client_cancel_rate: 0.10 input_length: @@ -120,10 +136,13 @@ stress_config: log_scan: hard_zero_patterns: - "Broken promise" - - "NO RECOVERY" - "Segfault" + - "Segmentation fault" - "SIGSEGV" - "0xffffffffffffffff" - - "Poisoned .* cache transfer buffer" + - "use-after-free" + - "heap-use-after-free" + - "AddressSanitizer:.*use-after-free" + - "double[- ]free" kv_cache_growth_max: 0.10 # final utilization ≤ baseline + 10 percentage points diff --git a/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_python_v2_qwen.yaml b/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_python_v2_qwen.yaml index d74108335002..79b7a4b73b38 100644 --- a/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_python_v2_qwen.yaml +++ b/tests/integration/defs/stress_test/disagg_cancel/configs/marathon_python_v2_qwen.yaml @@ -54,6 +54,11 @@ generation_servers: kv_transfer_timeout_ms: 60000 stress_config: + # Placeholder template for the future full marathon. This YAML is + # intentionally not parametrized until the runtime supports the + # cancellation + poison-buffer contract and canary references are + # recorded for Qwen2.5-7B-Instruct. + mode: full_cancel_poison duration_min: 120 kv_cache_manager: v2 diff --git a/tests/integration/defs/stress_test/disagg_cancel/harness.py b/tests/integration/defs/stress_test/disagg_cancel/harness.py index d85b9d88b936..f2eb18cb5f7a 100644 --- a/tests/integration/defs/stress_test/disagg_cancel/harness.py +++ b/tests/integration/defs/stress_test/disagg_cancel/harness.py @@ -33,8 +33,10 @@ import os import random import re +import shutil import signal import subprocess +import tempfile import threading import time import urllib.error @@ -47,6 +49,10 @@ logger = logging.getLogger(__name__) +_STRESS_MODE_LOG_ONLY = "log_only" +_STRESS_MODE_FULL_CANCEL_POISON = "full_cancel_poison" +_STRESS_MODES = (_STRESS_MODE_LOG_ONLY, _STRESS_MODE_FULL_CANCEL_POISON) + # --------------------------------------------------------------------------- # Config dataclasses @@ -58,6 +64,7 @@ # simply aren't passed to the constructor, so the field defaults # apply automatically and are not duplicated here. _STRESS_CONFIG_COERCERS: dict[str, Callable[[Any], Any]] = { + "mode": str, "duration_min": float, "kv_cache_manager": str, "transceiver": str, @@ -76,6 +83,7 @@ class StressConfig: pass them around without re-parsing. """ + mode: str = _STRESS_MODE_LOG_ONLY duration_min: float = 120.0 kv_cache_manager: str = "v1" # v1 | v2 (v2 + CPP is invalid) transceiver: str = "cpp" # cpp | python @@ -137,6 +145,8 @@ def validate(self) -> None: supplied (the C++ transceiver only supports the V1 KV cache manager). """ + if self.mode not in _STRESS_MODES: + raise ValueError(f"mode must be one of {_STRESS_MODES}, got {self.mode!r}") if self.kv_cache_manager == "v2" and self.transceiver == "cpp": # The C++ transceiver (BindKvCacheTransceiver) only supports # the V1 KV cache manager. V2 must be paired with the Python @@ -152,6 +162,16 @@ def validate(self) -> None: if self.transceiver not in ("cpp", "python"): raise ValueError(f"transceiver must be 'cpp' or 'python', got {self.transceiver!r}") + @property + def is_log_only(self) -> bool: + """True when the harness should run the regular CI guardrail mode.""" + return self.mode == _STRESS_MODE_LOG_ONLY + + @property + def is_full_cancel_poison(self) -> bool: + """True when the harness should run the full cancellation/poison marathon.""" + return self.mode == _STRESS_MODE_FULL_CANCEL_POISON + _INJECTION_TARGET_RE = re.compile(r"^(ctx|gen)_worker_(\d+)$") @@ -778,7 +798,8 @@ def _load_iteration_shape(config: StressConfig, elapsed_s: float) -> dict[str, A if interval_s <= 0.0: raise ValueError( - f"stress_config.bursts.interval_min must be positive, got {bursts.get('interval_min')!r}" + "stress_config.bursts.interval_min must be positive, got " + f"{bursts.get('interval_min')!r}" ) if duration_s <= 0.0: raise ValueError( @@ -916,6 +937,7 @@ def __init__( self._cluster: Any = None # tuple returned by setup_disagg_cluster self._worker_specs: list[WorkerLaunchSpec] = [] self._tracked_workers: list[_TrackedWorker] = [] + self._server_log_path: Optional[str] = None self._marathon_start_monotonic: float = 0.0 # Disagg-server front-end the canary targets; populated by @@ -926,6 +948,7 @@ def __init__( # Thread handles (populated by start()). self._load_thread: Optional[threading.Thread] = None + self._log_only_thread: Optional[threading.Thread] = None self._canary_thread: Optional[threading.Thread] = None self._injector_thread: Optional[threading.Thread] = None self._log_scanner_thread: Optional[threading.Thread] = None @@ -944,13 +967,256 @@ def __init__( def setup(self) -> None: """Launch the disagg cluster from the YAML and record launch specs. - Stub: real implementation delegates to ``setup_disagg_cluster`` - in ``tests/integration/defs/disaggregated/test_disaggregated.py`` + Delegates the process launch to ``setup_disagg_cluster`` in + ``tests/integration/defs/disaggregated/test_disaggregated.py`` and shadow-tracks per-worker ``WorkerLaunchSpec`` so the injector thread can later relaunch a SIGKILLed worker without - modifying shared infrastructure. + modifying shared infrastructure. The harness-only + ``stress_config`` block is stripped from the temporary YAML + passed to the shared launcher so worker config validation only + sees normal ``trtllm-serve`` settings. """ - logger.info("[harness] setup() — stub: cluster not actually launched") + from test_disaggregated import ( + build_worker_config, + get_default_disagg_cluster_config, + get_ucx_tls, + setup_disagg_cluster, + ) + + cluster_config = self._load_sanitized_cluster_config() + raw_model_name = str(cluster_config.get("model") or "") + if not raw_model_name: + raise ValueError(f"YAML at {self.yaml_path} is missing top-level model") + model_name = self._resolve_model_name(raw_model_name) + + server_start_timeout_s = int(self.config.raw.get("server_start_timeout_s", 1200)) + run_env = os.environ.copy() + run_env["UCX_TLS"] = get_ucx_tls() + + setup_yaml_path = self._write_sanitized_cluster_yaml(cluster_config) + try: + self._cluster = setup_disagg_cluster( + setup_yaml_path, + model_name=model_name, + env=run_env, + server_start_timeout=server_start_timeout_s, + save_log=True, + ) + finally: + try: + os.unlink(setup_yaml_path) + except OSError: + logger.debug("[harness] could not unlink %s; ignoring", setup_yaml_path) + + config, ctx_workers, gen_workers, disagg_server, server_port, work_dir = self._cluster + server_host = config.get("hostname", "localhost") + server_url = f"http://{server_host}:{server_port}" + + disagg_cluster = get_default_disagg_cluster_config() + disagg_cluster["cluster_uri"] = server_url + ctx_servers = config.get("context_servers", {}) + gen_servers = config.get("generation_servers", {}) + disagg_cluster["minimal_instances"] = { + "context_servers": ctx_servers.get("num_instances", 1), + "generation_servers": gen_servers.get("num_instances", 1), + } + ctx_worker_config = build_worker_config(config, ctx_servers, disagg_cluster) + gen_worker_config = build_worker_config(config, gen_servers, disagg_cluster) + ctx_specs, gen_specs = self._build_worker_launch_specs( + ctx_workers=ctx_workers, + gen_workers=gen_workers, + ctx_worker_config=ctx_worker_config, + gen_worker_config=gen_worker_config, + ctx_servers=ctx_servers, + gen_servers=gen_servers, + model_name=model_name, + work_dir=work_dir, + env=run_env, + host=server_host, + ) + self._refresh_worker_ports_from_cluster_info(server_url, ctx_specs, gen_specs) + self.bind_tracked_workers(ctx_workers, gen_workers, ctx_specs, gen_specs) + self.bind_server_endpoint(server_url, model_name) + self._server_log_path = getattr(disagg_server, "log_path", None) + logger.info( + "[harness] setup() launched %d ctx worker(s), %d gen worker(s), server=%s", + len(ctx_workers), + len(gen_workers), + server_url, + ) + + def _load_sanitized_cluster_config(self) -> dict[str, Any]: + """Load YAML and remove harness-only fields before cluster launch.""" + with self.yaml_path.open("r", encoding="utf-8") as f: + doc = yaml.safe_load(f) + if not isinstance(doc, dict): + raise ValueError(f"YAML at {self.yaml_path} must be a mapping") + cluster_config = dict(doc) + cluster_config.pop("stress_config", None) + return cluster_config + + def _write_sanitized_cluster_yaml(self, cluster_config: dict[str, Any]) -> str: + """Write the launcher-facing YAML to a temporary file.""" + fd, path = tempfile.mkstemp(prefix="disagg_cancel_cluster_", suffix=".yaml") + with os.fdopen(fd, "w", encoding="utf-8") as f: + yaml.safe_dump(cluster_config, f) + return path + + def _resolve_model_name(self, model_name: str) -> str: + """Resolve relative model names against ``LLM_MODELS_ROOT`` when set.""" + path = Path(model_name).expanduser() + if path.is_absolute() or path.exists(): + return str(path) + models_root = os.environ.get("LLM_MODELS_ROOT") + if models_root: + return str(Path(models_root).expanduser() / model_name) + return model_name + + def _build_worker_launch_specs( + self, + *, + ctx_workers: list[Any], + gen_workers: list[Any], + ctx_worker_config: dict[str, Any], + gen_worker_config: dict[str, Any], + ctx_servers: dict[str, Any], + gen_servers: dict[str, Any], + model_name: str, + work_dir: str, + env: dict[str, str], + host: str, + ) -> tuple[list[WorkerLaunchSpec], list[WorkerLaunchSpec]]: + """Reconstruct worker launch metadata for log scanning and respawn.""" + import torch + + num_gpus = torch.cuda.device_count() + if num_gpus <= 0: + raise RuntimeError("setup_disagg_cluster returned, but torch reports no CUDA devices") + + gpus_per_ctx = ( + int(ctx_servers.get("tensor_parallel_size", 1)) + * int(ctx_servers.get("pipeline_parallel_size", 1)) + * int(ctx_servers.get("context_parallel_size", 1)) + ) + gpus_per_gen = ( + int(gen_servers.get("tensor_parallel_size", 1)) + * int(gen_servers.get("pipeline_parallel_size", 1)) + * int(gen_servers.get("context_parallel_size", 1)) + ) + + ctx_specs: list[WorkerLaunchSpec] = [] + gen_specs: list[WorkerLaunchSpec] = [] + next_device = 0 + for index, wrapper in enumerate(ctx_workers): + device = self._format_device_ids(next_device, gpus_per_ctx, num_gpus) + next_device += gpus_per_ctx + ctx_specs.append( + self._make_worker_launch_spec( + role="ctx", + index=index, + wrapper=wrapper, + worker_config=ctx_worker_config, + model_name=model_name, + work_dir=work_dir, + device=device, + env=env, + host=host, + ) + ) + for index, wrapper in enumerate(gen_workers): + device = self._format_device_ids(next_device, gpus_per_gen, num_gpus) + next_device += gpus_per_gen + gen_specs.append( + self._make_worker_launch_spec( + role="gen", + index=index, + wrapper=wrapper, + worker_config=gen_worker_config, + model_name=model_name, + work_dir=work_dir, + device=device, + env=env, + host=host, + ) + ) + return ctx_specs, gen_specs + + def _format_device_ids(self, first_device: int, count: int, num_gpus: int) -> str: + """Return the CUDA_VISIBLE_DEVICES string used by setup_disagg_cluster.""" + return ",".join( + str(d) for d in dict.fromkeys((first_device + j) % num_gpus for j in range(count)) + ) + + def _make_worker_launch_spec( + self, + *, + role: str, + index: int, + wrapper: Any, + worker_config: dict[str, Any], + model_name: str, + work_dir: str, + device: str, + env: dict[str, str], + host: str, + ) -> WorkerLaunchSpec: + """Create one shadow launch spec from the shared ProcessWrapper.""" + return WorkerLaunchSpec( + role=role, + index=index, + model_name=model_name, + worker_config=worker_config, + work_dir=work_dir, + port=int(getattr(wrapper, "port", 0) or 0), + device=device, + env=env.copy(), + log_path=getattr(wrapper, "log_path", None), + host=host, + ) + + def _refresh_worker_ports_from_cluster_info( + self, + server_url: str, + ctx_specs: list[WorkerLaunchSpec], + gen_specs: list[WorkerLaunchSpec], + ) -> None: + """Populate worker host/port from disagg ``/cluster_info`` when available.""" + try: + with urllib.request.urlopen(f"{server_url}/cluster_info", timeout=5.0) as response: + info = json.loads(response.read().decode("utf-8", errors="replace")) + except (json.JSONDecodeError, TimeoutError, OSError, urllib.error.URLError) as exc: + logger.warning("[harness] could not read cluster_info for worker ports: %s", exc) + return + + current_workers = info.get("current_workers") or {} + for specs, key in ( + (ctx_specs, "context_servers"), + (gen_specs, "generation_servers"), + ): + workers = current_workers.get(key) or [] + if len(workers) != len(specs): + logger.warning( + "[harness] cluster_info %s count mismatch: %d worker(s), %d spec(s)", + key, + len(workers), + len(specs), + ) + for spec, worker_info in zip(specs, workers): + if not isinstance(worker_info, dict): + continue + host = worker_info.get("host") + port = worker_info.get("port") + if isinstance(host, str) and host: + spec.host = host + try: + spec.port = int(port) + except (TypeError, ValueError): + logger.warning( + "[harness] cluster_info %s worker %d has invalid port %r", + key, + spec.index, + port, + ) def bind_tracked_workers( self, @@ -988,30 +1254,47 @@ def bind_server_endpoint(self, server_url: str, model_name: str) -> None: self._model_name = model_name def start(self) -> None: - """Spawn the five worker threads. Returns immediately. + """Spawn the mode-specific worker threads. Returns immediately. + + ``log_only`` mode runs the regular CI guardrail: a normal + non-cancel probe loop plus log-pattern fail-fast. It + intentionally avoids the cancellation load, fault injector, + poison canary, and KV-growth gates until those runtime fixes + are present. - If ``setup()`` has not bound a live server endpoint yet, the - load thread warns and signals ``stop_event`` so the lifecycle - smoke still completes cleanly without waiting out the - ``wait_until_done`` timeout. + ``full_cancel_poison`` mode runs all five full-stress threads. """ self._marathon_start_monotonic = time.monotonic() - logger.info("[harness] start() — spawning worker threads") - self._load_thread = threading.Thread( - target=self._load_thread_body, name="stress-load", daemon=True - ) - self._canary_thread = threading.Thread( - target=self._canary_thread_body, name="stress-canary", daemon=True - ) - self._injector_thread = threading.Thread( - target=self._injector_thread_body, name="stress-injector", daemon=True - ) - self._log_scanner_thread = threading.Thread( - target=self._log_scanner_thread_body, name="stress-log-scanner", daemon=True - ) - self._metrics_thread = threading.Thread( - target=self._metrics_thread_body, name="stress-metrics", daemon=True - ) + logger.info("[harness] start() — mode=%s", self.config.mode) + if self.config.is_log_only: + self._log_only_thread = threading.Thread( + target=self._log_only_thread_body, + name="stress-log-only-probe", + daemon=True, + ) + self._log_scanner_thread = threading.Thread( + target=self._log_scanner_thread_body, + name="stress-log-scanner", + daemon=True, + ) + elif self.config.is_full_cancel_poison: + self._load_thread = threading.Thread( + target=self._load_thread_body, name="stress-load", daemon=True + ) + self._canary_thread = threading.Thread( + target=self._canary_thread_body, name="stress-canary", daemon=True + ) + self._injector_thread = threading.Thread( + target=self._injector_thread_body, name="stress-injector", daemon=True + ) + self._log_scanner_thread = threading.Thread( + target=self._log_scanner_thread_body, + name="stress-log-scanner", + daemon=True, + ) + self._metrics_thread = threading.Thread( + target=self._metrics_thread_body, name="stress-metrics", daemon=True + ) for t in self._all_threads(): t.start() @@ -1143,9 +1426,126 @@ def collect_results(self) -> dict[str, Any]: } # ------------------------------------------------------------------ - # Thread bodies (stubs — implemented incrementally) + # Thread bodies # ------------------------------------------------------------------ + def _configured_duration_s(self) -> float: + """Return the active run duration, honoring unit-test overrides.""" + if self._load_duration_s is not None: + return self._load_duration_s + return float(self.config.duration_min) * 60.0 + + def _log_only_thread_body(self) -> None: + """Run regular CI protection without cancellation or poison gates. + + This mode still launches the real disaggregated cluster and + sends normal completion probes through the front-end. It fails + the test on probe errors and runs concurrently with + ``log_scanner_thread`` so UAF, broken-promise, and segfault + signatures in worker/server logs remain hard-zero failures. + """ + if not self._server_url: + self.mark_failed("log_only mode requires setup() to bind a server endpoint") + self.stop_event.set() + return + + duration_s = self._configured_duration_s() + if duration_s <= 0.0: + logger.info("[log_only] non-positive duration %.3fs; exiting", duration_s) + self.stop_event.set() + return + + probe_cfg = self.config.raw.get("log_only_probe") or {} + try: + interval_s = float(probe_cfg.get("interval_s", 30.0)) + max_tokens = int(probe_cfg.get("max_tokens", 32)) + seed = int(probe_cfg.get("seed", 42)) + timeout_s = float(probe_cfg.get("request_timeout_s", self._canary_request_timeout_s)) + prompt = str( + probe_cfg.get( + "prompt", + "Write one sentence about reliable distributed inference.", + ) + ) + except (TypeError, ValueError) as exc: + self.mark_failed(f"log_only_probe config error: {exc}") + self.stop_event.set() + return + if interval_s <= 0.0: + self.mark_failed(f"log_only_probe.interval_s must be positive, got {interval_s}") + self.stop_event.set() + return + + deadline = time.monotonic() + duration_s + logger.info( + "[log_only] probing %s every %.1fs for %.1fs", + self._server_url, + interval_s, + duration_s, + ) + + while ( + time.monotonic() < deadline + and not self.stop_event.is_set() + and not self.failed_event.is_set() + ): + send_start = time.monotonic() + token_ids, _, err = self._send_log_only_probe( + prompt=prompt, + max_tokens=max_tokens, + seed=seed, + timeout_s=timeout_s, + ) + success = err is None + self._canary_records.append( + { + "timestamp": time.time(), + "elapsed_s": time.monotonic() - self._marathon_start_monotonic, + "mode": _STRESS_MODE_LOG_ONLY, + "prompt_index": 0, + "success": success, + "token_equivalent": None, + "latency_s": time.monotonic() - send_start, + "error": err, + "token_count": len(token_ids or []), + } + ) + if not success: + self.mark_failed(f"log_only probe failed: {err}") + break + + remaining = min(interval_s, max(0.0, deadline - time.monotonic())) + if remaining > 0.0: + self.stop_event.wait(timeout=remaining) + + if not self.failed_event.is_set() and not any( + record.get("success") for record in self._canary_records + ): + self.mark_failed("log_only mode completed without a successful probe") + if not self.failed_event.is_set(): + logger.info("[log_only] completed; signalling stop_event") + self.stop_event.set() + + def _send_log_only_probe( + self, + *, + prompt: str, + max_tokens: int, + seed: int, + timeout_s: float, + ) -> tuple[Optional[list[int]], Optional[str], Optional[str]]: + """Send one normal completion request for ``log_only`` mode.""" + if self._server_url is None: + return None, None, "missing_server_url" + return _send_canary_request( + self._server_url, + self._model_name or "log-only-probe", + prompt, + max_tokens, + seed, + timeout_s, + ) + def _load_thread_body(self) -> None: """Wrap ``run_cancel_stress_test`` in a duration-bounded loop. @@ -1164,11 +1564,7 @@ def _load_thread_body(self) -> None: self.stop_event.set() return - duration_s = ( - self._load_duration_s - if self._load_duration_s is not None - else float(self.config.duration_min) * 60.0 - ) + duration_s = self._configured_duration_s() if duration_s <= 0.0: logger.info("[load_thread] non-positive duration %.3fs; exiting", duration_s) self.stop_event.set() @@ -1570,6 +1966,19 @@ def _log_scanner_thread_body(self) -> None: return sources: list[_LogSource] = [] + if self._server_log_path is not None: + server_spec = WorkerLaunchSpec( + role="server", + index=0, + model_name=self._model_name or "disagg-server", + worker_config={}, + work_dir="", + port=0, + device="", + env={}, + log_path=self._server_log_path, + ) + sources.append(_LogSource(spec=server_spec, path=Path(self._server_log_path))) for spec in self._worker_specs: if spec.log_path is None: logger.warning( @@ -1589,7 +1998,7 @@ def _log_scanner_thread_body(self) -> None: return logger.info( - "[log_scanner] tailing %d worker log(s) against %d hard_zero pattern(s)", + "[log_scanner] tailing %d log source(s) against %d hard_zero pattern(s)", len(sources), len(patterns), ) @@ -1682,6 +2091,7 @@ def _all_threads(self) -> list[threading.Thread]: t for t in ( self._load_thread, + self._log_only_thread, self._canary_thread, self._injector_thread, self._log_scanner_thread, @@ -1691,10 +2101,17 @@ def _all_threads(self) -> list[threading.Thread]: ] def _teardown_cluster(self) -> None: - """Best-effort cluster shutdown via ``terminate()``. - - Stub: no-op since ``setup()`` doesn't actually launch yet. - """ + """Best-effort cluster shutdown via ``terminate()``.""" if self._cluster is None: return - logger.info("[harness] _teardown_cluster — stub") + from disagg_test_utils import terminate + + config, ctx_workers, gen_workers, disagg_server, _server_port, work_dir = self._cluster + del config + logger.info("[harness] tearing down disagg cluster work_dir=%s", work_dir) + try: + terminate(*ctx_workers, *gen_workers, disagg_server) + finally: + shutil.rmtree(work_dir, ignore_errors=True) + self._cluster = None + self._server_log_path = None diff --git a/tests/integration/defs/stress_test/disagg_cancel/test_disagg_cancel_stress.py b/tests/integration/defs/stress_test/disagg_cancel/test_disagg_cancel_stress.py index 36d5b607d3bb..aa5c74b51779 100644 --- a/tests/integration/defs/stress_test/disagg_cancel/test_disagg_cancel_stress.py +++ b/tests/integration/defs/stress_test/disagg_cancel/test_disagg_cancel_stress.py @@ -25,7 +25,11 @@ from __future__ import annotations +import textwrap +import threading +import time from pathlib import Path +from typing import Any import pytest @@ -57,29 +61,7 @@ def test_all_marathon_yamls_parse_and_validate() -> None: @pytest.mark.parametrize("config_filename", _MARATHON_CONFIGS) def test_disagg_cancellation_marathon(config_filename: str) -> None: - """Drive a long-running disagg cancellation marathon and assert pass criteria. - - Current scope: only what the already-implemented thread bodies - can contribute. The marathon entry point exists; the marathon - *content* lands incrementally as setup / pass-criteria wiring is - completed: - - - lifecycle plumbing (setup -> start -> wait -> stop -> - collect_results, fail-fast event propagation, dict-shape - contract). - - log-pattern fail-fast — a hard-zero pattern in any worker log - trips ``failure_reason`` via the log_scanner thread - (component-level coverage in ``test_log_scanner.py``). - - Marathon pass criteria not yet enforced here (will land alongside - their owning result aggregation in follow-up changes): canary error - rate, recovery time after each injection, KV-cache utilization - growth bound, injection-schedule completeness, sustained load - throughput. Until those land, this test passes trivially after - the lifecycle smoke completes; the value at this stage is that - the entry point and result-dict contract are pinned down so the - follow-up commits can extend in place rather than restructure. - """ + """Drive the configured disagg stress mode and assert current pass criteria.""" config_path = _CONFIG_DIR / config_filename assert config_path.exists(), ( f"Marathon config not found: {config_path}. " @@ -90,13 +72,8 @@ def test_disagg_cancellation_marathon(config_filename: str) -> None: try: harness.setup() harness.start() - # setup() is still a stub, so no server endpoint is bound. - # The load thread exits and signals ``stop_event`` on that - # no-endpoint path, which lets this lifecycle smoke complete - # almost instantly. Once setup launches a real cluster, the - # timeout becomes ``stress_config.duration_min`` plus a safety - # margin. - clean = harness.wait_until_done(timeout_s=10.0) + timeout_s = float(harness.config.duration_min) * 60.0 + 300.0 + clean = harness.wait_until_done(timeout_s=timeout_s) assert clean is True, ( f"wait_until_done did not return cleanly; failure_reason={harness.failure_reason!r}" ) @@ -112,5 +89,109 @@ def test_disagg_cancellation_marathon(config_filename: str) -> None: assert "kv_utilization_samples" in results assert "injection_events" in results assert results["failure_reason"] is None, ( - f"Harness tripped fail-fast in skeleton run: {results['failure_reason']!r}" + f"Harness tripped fail-fast: {results['failure_reason']!r}" + ) + if harness.config.is_log_only: + assert any( + record.get("mode") == "log_only" and record.get("success") + for record in results["canary_records"] + ), "log_only mode completed without a successful server probe" + + +def _write_mode_yaml(tmp_path: Path, stress_config: str) -> Path: + """Write a minimal marathon YAML for mode-level harness tests.""" + yaml_path = tmp_path / "mode.yaml" + content = textwrap.dedent( + """\ + hostname: localhost + model: dummy + backend: pytorch + context_servers: {} + generation_servers: {} + stress_config: + """ + ) + content += textwrap.indent(textwrap.dedent(stress_config).strip(), " ") + "\n" + yaml_path.write_text(content) + return yaml_path + + +@pytest.mark.parametrize("mode", ["log_only", "full_cancel_poison"]) +def test_stress_config_accepts_supported_modes(tmp_path: Path, mode: str) -> None: + """Both supported mode strings should parse and expose helper predicates.""" + cfg = StressConfig.from_yaml_path( + _write_mode_yaml( + tmp_path, + f"""\ + mode: {mode} + duration_min: 1 + kv_cache_manager: v1 + transceiver: cpp + """, + ) + ) + + assert cfg.mode == mode + assert cfg.is_log_only is (mode == "log_only") + assert cfg.is_full_cancel_poison is (mode == "full_cancel_poison") + + +def test_stress_config_rejects_unknown_mode(tmp_path: Path) -> None: + """Typos in mode must fail during YAML validation.""" + with pytest.raises(ValueError, match="mode must be one of"): + StressConfig.from_yaml_path( + _write_mode_yaml( + tmp_path, + """\ + mode: accidental + duration_min: 1 + kv_cache_manager: v1 + transceiver: cpp + """, + ) + ) + + +def test_log_only_thread_sends_probe_and_stops( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """The regular-protection mode should require at least one clean probe.""" + h = DisaggCancellationStressHarness( + _write_mode_yaml( + tmp_path, + """\ + mode: log_only + duration_min: 1 + kv_cache_manager: v1 + transceiver: cpp + log_only_probe: + interval_s: 0.01 + max_tokens: 8 + request_timeout_s: 1 + log_scan: + hard_zero_patterns: + - "Broken promise" + """, + ), + load_duration_s=0.03, ) + h.bind_server_endpoint("http://127.0.0.1:8000", "test-model") + h._marathon_start_monotonic = time.monotonic() + + calls: list[dict[str, Any]] = [] + + def fake_probe(**kwargs: Any) -> tuple[list[int], None, None]: + calls.append(kwargs) + return [1, 2], None, None + + monkeypatch.setattr(h, "_send_log_only_probe", fake_probe) + + thread = threading.Thread(target=h._log_only_thread_body, name="test-log-only", daemon=True) + thread.start() + thread.join(timeout=2.0) + + assert not thread.is_alive() + assert h.stop_event.is_set() + assert not h.failed_event.is_set() + assert calls + assert any(record["success"] for record in h._canary_records) diff --git a/tests/integration/test_lists/qa/llm_function_stress.txt b/tests/integration/test_lists/qa/llm_function_stress.txt index 1bf7c4f3f77b..9d31b6149c46 100644 --- a/tests/integration/test_lists/qa/llm_function_stress.txt +++ b/tests/integration/test_lists/qa/llm_function_stress.txt @@ -5,6 +5,7 @@ disaggregated/test_disaggregated.py::test_disaggregated_stress_test[input8k-outp disaggregated/test_disaggregated.py::test_disaggregated_stress_test[input8k-output1k-conc512-gpt_oss_120b_eagle_trtllm_stress] disaggregated/test_disaggregated.py::test_disaggregated_stress_test[input8k-output1k-conc512-gpt_oss_120b_triton_stress] disaggregated/test_disaggregated.py::test_disaggregated_stress_test[input8k-output1k-conc512-qwen3_5_4b_fp8_stress] +stress_test/disagg_cancel/test_disagg_cancel_stress.py::test_disagg_cancellation_marathon[marathon_cpp_v1_deepseek.yaml] TIMEOUT (45) accuracy/test_llm_api_pytorch.py::TestDeepSeekR1LongBenchV2::test_fp8_8gpus accuracy/test_llm_api_pytorch.py::TestDeepSeekR1LongBenchV2::test_nvfp4_4gpus accuracy/test_llm_api_pytorch.py::TestKimiK2::test_nvfp4_longseq_trtllm_moe_stress From 03ed843cd76420985329d5689deef5fb6c928b1d Mon Sep 17 00:00:00 2001 From: jthomson04 Date: Wed, 10 Jun 2026 10:09:52 -0700 Subject: [PATCH 6/6] [None][feat] Preserve cache_salt string in KV cache events (#13051) Signed-off-by: jthomson04 --- benchmarks/cpp/disaggServerBenchmark.cpp | 4 +- benchmarks/cpp/gptManagerBenchmark.cpp | 4 +- .../tensorrt_llm/batch_manager/blockKey.h | 12 +-- .../batch_manager/kvCacheManager.h | 1 - .../tensorrt_llm/batch_manager/llmRequest.h | 34 ++++---- cpp/include/tensorrt_llm/executor/executor.h | 15 ++-- cpp/include/tensorrt_llm/executor/types.h | 1 - cpp/include/tensorrt_llm/runtime/common.h | 1 - cpp/tensorrt_llm/batch_manager/blockKey.cpp | 10 +-- .../batch_manager/kvCacheEventManager.cpp | 3 +- .../batch_manager/kvCacheManager.cpp | 6 +- cpp/tensorrt_llm/executor/request.cpp | 14 +-- cpp/tensorrt_llm/executor/requestImpl.h | 31 +++++-- cpp/tensorrt_llm/executor/serialization.cpp | 17 ++-- .../nanobind/batch_manager/bindings.cpp | 20 ++--- .../nanobind/batch_manager/llmRequest.cpp | 4 +- .../nanobind/batch_manager/llmRequest.h | 3 +- .../nanobind/executor/bindings.cpp | 1 + .../nanobind/executor/request.cpp | 16 ++-- .../batch_manager/agentTreeTest.cpp | 2 +- .../batch_manager/kvCacheManagerTest.cpp | 52 ++++++------ .../executor/serializeUtilsTest.cpp | 2 +- .../cpp/executor/executorExampleKvEvents.cpp | 12 +-- examples/llm-api/llm_kv_cache_connector.py | 11 ++- .../connectors/kv_cache_connector.py | 6 +- tensorrt_llm/_torch/pyexecutor/llm_request.py | 2 +- .../_torch/pyexecutor/resource_manager.py | 31 ++++--- tensorrt_llm/_utils.py | 2 + tensorrt_llm/executor/base_worker.py | 2 +- tensorrt_llm/executor/executor.py | 9 +- tensorrt_llm/executor/request.py | 20 ++++- tensorrt_llm/inputs/__init__.py | 4 +- tensorrt_llm/inputs/utils.py | 14 +-- tensorrt_llm/llmapi/llm.py | 6 +- .../llmapi/test_llm_kv_cache_events.py | 85 ++++++++++++++++++- 35 files changed, 287 insertions(+), 170 deletions(-) diff --git a/benchmarks/cpp/disaggServerBenchmark.cpp b/benchmarks/cpp/disaggServerBenchmark.cpp index 057e7898afeb..bc3a7a2659fd 100644 --- a/benchmarks/cpp/disaggServerBenchmark.cpp +++ b/benchmarks/cpp/disaggServerBenchmark.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -543,7 +543,7 @@ texec::Request makeExecutorContextRequest(Sample const& sample, SizeType32 const std::nullopt, // logitsPostProcessorName std::nullopt, // logitsPostProcessor encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt, - std::nullopt); // cacheSaltID + std::nullopt); // cacheSalt request.setRequestType(tensorrt_llm::executor::RequestType::REQUEST_TYPE_CONTEXT_ONLY); return request; } diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp index b4f0948c1155..287cbba343ce 100644 --- a/benchmarks/cpp/gptManagerBenchmark.cpp +++ b/benchmarks/cpp/gptManagerBenchmark.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -838,7 +838,7 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW std::nullopt, // logitsPostProcessorName std::nullopt, // logitsPostProcessor encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt, - std::nullopt); // cacheSaltID + std::nullopt); // cacheSalt } void benchmarkExecutor(std::optional const& decoderEngineDir, diff --git a/cpp/include/tensorrt_llm/batch_manager/blockKey.h b/cpp/include/tensorrt_llm/batch_manager/blockKey.h index 002b4356c869..920212845331 100644 --- a/cpp/include/tensorrt_llm/batch_manager/blockKey.h +++ b/cpp/include/tensorrt_llm/batch_manager/blockKey.h @@ -29,7 +29,6 @@ using VecTokens = std::vector; using UniqueToken = tensorrt_llm::runtime::UniqueToken; using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens; using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType; -using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType; using MmKey = tensorrt_llm::executor::MmKey; //! \brief Generate the multimodal extra keys for a single KV cache block. @@ -49,7 +48,8 @@ struct BlockKey // Extra keys for multimodal data (similar to VLLM's approach) // Each extra key is a pair of (mm_hash, start_offset_in_block) std::vector extraKeys; - std::optional cacheSaltID = std::nullopt; + // Cache salt string. Used as part of the block key so blocks from different salts do not match. + std::optional cacheSalt = std::nullopt; BlockKey() = default; @@ -64,12 +64,12 @@ struct BlockKey } explicit BlockKey(bool usesExtraIds, std::optional loraTaskId, VecUniqueTokens uniqueTokens, - std::vector extraKeys = {}, std::optional cacheSaltID = std::nullopt) + std::vector extraKeys = {}, std::optional cacheSalt = std::nullopt) : usesExtraIds{usesExtraIds} , loraTaskId{loraTaskId} , uniqueTokens{std::move(uniqueTokens)} , extraKeys{std::move(extraKeys)} - , cacheSaltID{cacheSaltID} + , cacheSalt{std::move(cacheSalt)} { } @@ -86,7 +86,7 @@ struct BlockKey } //! \brief Count the number of leading tokens that match between this key and \p other. - //! \details Returns 0 immediately when loraTaskId, extraKeys, or cacheSaltID differ, because those fields must + //! \details Returns 0 immediately when loraTaskId, extraKeys, or cacheSalt differ, because those fields must //! match exactly before token content is considered. //! \param other The key to compare against. //! \return Number of leading uniqueTokens that are identical in both keys. @@ -94,7 +94,7 @@ struct BlockKey { SizeType32 numMatched{0}; if (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId && extraKeys == other.extraKeys - && cacheSaltID == other.cacheSaltID) + && cacheSalt == other.cacheSalt) { auto [matchEnd, otherMatchEnd] = std::mismatch( uniqueTokens.begin(), uniqueTokens.end(), other.uniqueTokens.begin(), other.uniqueTokens.end()); diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index d3966adf2f20..c665f7a8df95 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -82,7 +82,6 @@ using UniqueToken = tensorrt_llm::runtime::UniqueToken; using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens; using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType; using BlocksPerWindow = std::map>; -using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType; using MmKey = tensorrt_llm::executor::MmKey; using WindowSizeType = SizeType32; diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 263a15b50970..886147a09c73 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -108,7 +108,6 @@ class GenericLlmRequest using MillisecondsType = std::chrono::milliseconds; using TimePoint = std::chrono::time_point; using Duration = std::chrono::time_point::duration; - using CacheSaltIDType = runtime::CacheSaltIDType; GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr const& inputTokens, runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional endId = std::nullopt, @@ -147,11 +146,12 @@ class GenericLlmRequest std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt, + std::optional arrivalTime = std::nullopt, std::optional>> agent_hierarchy = std::nullopt, std::optional>> multimodalItemRunCuOffsets = std::nullopt, std::optional>> multimodalRunPositions = std::nullopt, - std::optional>> multimodalRunLengths = std::nullopt) + std::optional>> multimodalRunLengths = std::nullopt, + std::optional cacheSalt = std::nullopt) : mRequestId(requestId) , mPromptLen(inputTokens->size()) , mMaxNewTokens(maxNewTokens) @@ -213,7 +213,7 @@ class GenericLlmRequest , mGuidedDecodingParams(std::move(guidedDecodingParams)) , mLanguageAdapterUid(languageAdapterUid) , mAllottedTimeMs(allottedTimeMs) - , mCacheSaltID(cacheSaltID) + , mCacheSalt(std::move(cacheSalt)) , mAgentHierarchy(std::move(agent_hierarchy)) { if (mEncoderTokens.has_value() || encoderInputFeatures.has_value()) @@ -242,7 +242,7 @@ class GenericLlmRequest executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1, std::optional languageAdapterUid = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt) + std::optional cacheSalt = std::nullopt) : mRequestId(requestId) , mPromptLen(inputTokens.size()) , mMaxNewTokens(maxNewTokens) @@ -283,7 +283,7 @@ class GenericLlmRequest , mContextPhaseParams(contextPhaseParams) , mNumReturnSequences(numReturnSequences) , mLanguageAdapterUid(languageAdapterUid) - , mCacheSaltID(cacheSaltID) + , mCacheSalt(std::move(cacheSalt)) { if (mEncoderTokens.has_value()) { @@ -323,7 +323,7 @@ class GenericLlmRequest , mGuidedDecodingParams(req.getGuidedDecodingParams()) , mLanguageAdapterUid(req.getLanguageAdapterUid()) , mAllottedTimeMs(req.getAllottedTimeMs()) - , mCacheSaltID(req.getCacheSaltID()) + , mCacheSalt(req.getCacheSalt()) { if (req.getRequestType() == executor::RequestType::REQUEST_TYPE_GENERATION_ONLY) { @@ -1897,9 +1897,9 @@ class GenericLlmRequest return mLanguageAdapterUid; } - [[nodiscard]] std::optional getCacheSaltID() const + [[nodiscard]] std::optional getCacheSalt() const { - return mCacheSaltID; + return mCacheSalt; } std::vector getLanguageAdapterRouting( @@ -2196,8 +2196,8 @@ class GenericLlmRequest bool mUseDraftModel{false}; - // Cache salt id for each request. - std::optional mCacheSaltID{std::nullopt}; + // Cache salt string. Used in BlockKey hashing/matching and surfaced in KV cache events. + std::optional mCacheSalt{std::nullopt}; std::optional>> mAgentHierarchy{std::nullopt}; @@ -2394,11 +2394,12 @@ class LlmRequest : public GenericLlmRequest std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt, + std::optional arrivalTime = std::nullopt, std::optional>> agent_hierarchy = std::nullopt, std::optional> multimodalItemRunCuOffsets = std::nullopt, std::optional> multimodalRunPositions = std::nullopt, - std::optional> multimodalRunLengths = std::nullopt) + std::optional> multimodalRunLengths = std::nullopt, + std::optional cacheSalt = std::nullopt) : Base(requestId, maxNewTokens, std::make_shared>(std::move(inputTokens)), samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), @@ -2431,8 +2432,8 @@ class LlmRequest : public GenericLlmRequest inputTokenExtraIds ? std::make_optional(std::make_shared(std::move(*inputTokenExtraIds))) : std::optional>(std::nullopt), numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics, - std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, cacheSaltID, - arrivalTime, std::move(agent_hierarchy), + std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, arrivalTime, + std::move(agent_hierarchy), multimodalItemRunCuOffsets.has_value() ? std::make_shared>(std::move(multimodalItemRunCuOffsets.value())) : std::optional>>(std::nullopt), @@ -2441,7 +2442,8 @@ class LlmRequest : public GenericLlmRequest : std::optional>>(std::nullopt), multimodalRunLengths.has_value() ? std::make_shared>(std::move(multimodalRunLengths.value())) - : std::optional>>(std::nullopt)) + : std::optional>>(std::nullopt), + std::move(cacheSalt)) { } diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 1f625d57084c..f716bef6e3cc 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -714,8 +714,9 @@ class Request /// @param allottedTimeMs The allotted time in milliseconds after which the request is cancelled with a timedOut /// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism /// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled. - /// @param cacheSaltID Salt ID for KV cache blocks to limit the kv cache reuse to the requests with the same string. /// @param disaggRequestId Disaggregated request ID. + /// @param cacheSalt Optional cache salt string. If provided, KV cache blocks are tagged so reuse is limited to + /// requests with the same salt. The string is also surfaced in KV cache events. Defaults to std::nullopt. Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false, SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(), std::optional const& endId = std::nullopt, std::optional const& padId = std::nullopt, @@ -743,8 +744,7 @@ class Request std::optional guidedDecodingParams = std::nullopt, std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, - std::optional cacheSaltID = std::nullopt, - std::optional disaggRequestId = std::nullopt); + std::optional disaggRequestId = std::nullopt, std::optional cacheSalt = std::nullopt); /// @brief This logits postprocessor name will dispatch to the batched logits postprocessor static auto constexpr kBatchedPostProcessorName = "batched"; @@ -792,7 +792,7 @@ class Request [[nodiscard]] std::optional getGuidedDecodingParams() const; [[nodiscard]] std::optional getLanguageAdapterUid() const; [[nodiscard]] std::optional getAllottedTimeMs() const; - [[nodiscard]] std::optional getCacheSaltID() const; + [[nodiscard]] std::optional getCacheSalt() const; [[nodiscard]] std::optional> getAdditionalOutputNames() const; [[nodiscard]] std::optional getDisaggRequestId() const; @@ -829,7 +829,7 @@ class Request void setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams); void setLanguageAdapterUid(SizeType32 languageAdapterUid); void setAllottedTimeMs(MillisecondsType allottedTimeMs); - void setCacheSaltID(CacheSaltIDType cacheSaltID); + void setCacheSalt(std::optional cacheSalt); void setDisaggRequestId(IdType disaggRequestId); private: @@ -1729,13 +1729,14 @@ struct KVCacheStoredBlockData KVCacheStoredBlockData(IdType blockHash, tensorrt_llm::runtime::VecUniqueTokens tokens, std::optional loraId, SizeType32 cacheLevel, SizeType32 priority, - std::vector mmKeys = {}) + std::vector mmKeys = {}, std::optional cacheSalt = std::nullopt) : blockHash{blockHash} , tokens{std::move(tokens)} , loraId{loraId} , cacheLevel{cacheLevel} , priority{priority} , mmKeys{std::move(mmKeys)} + , cacheSalt{std::move(cacheSalt)} { } @@ -1751,6 +1752,8 @@ struct KVCacheStoredBlockData SizeType32 priority; /// @brief The multimodal keys of the block std::vector mmKeys; + /// @brief The original cache salt string of the block, if any + std::optional cacheSalt; }; struct KVCacheStoredData diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h index 0800865df7f1..2e6051291629 100644 --- a/cpp/include/tensorrt_llm/executor/types.h +++ b/cpp/include/tensorrt_llm/executor/types.h @@ -59,7 +59,6 @@ using RandomSeedType = std::uint64_t; using VecLogProbs = std::vector; using StreamPtr = std::shared_ptr; using MillisecondsType = std::chrono::milliseconds; -using CacheSaltIDType = std::uint64_t; using LogitsPostProcessor = std::function)>; using LogitsPostProcessorMap = std::unordered_map; diff --git a/cpp/include/tensorrt_llm/runtime/common.h b/cpp/include/tensorrt_llm/runtime/common.h index 7a3079d0bd75..2cda8821c133 100644 --- a/cpp/include/tensorrt_llm/runtime/common.h +++ b/cpp/include/tensorrt_llm/runtime/common.h @@ -44,7 +44,6 @@ using TokenIdType = std::int32_t; using LoraTaskIdType = std::uint64_t; using TokenExtraIdType = std::uint64_t; using VecTokenExtraIds = std::vector; -using CacheSaltIDType = std::uint64_t; struct UniqueToken { diff --git a/cpp/tensorrt_llm/batch_manager/blockKey.cpp b/cpp/tensorrt_llm/batch_manager/blockKey.cpp index 33092a5a37fa..e8125b0106f4 100644 --- a/cpp/tensorrt_llm/batch_manager/blockKey.cpp +++ b/cpp/tensorrt_llm/batch_manager/blockKey.cpp @@ -334,7 +334,7 @@ std::vector buildBlockKeys( currentTokenIdx += uniqueTokens.size(); blockKeys.emplace_back(llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(), - std::move(uniqueTokens), std::move(extraKeys), llmRequest.getCacheSaltID()); + std::move(uniqueTokens), std::move(extraKeys), llmRequest.getCacheSalt()); } return blockKeys; } @@ -342,7 +342,7 @@ std::vector buildBlockKeys( bool BlockKey::operator==(BlockKey const& other) const noexcept { return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId && uniqueTokens == other.uniqueTokens - && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID); + && extraKeys == other.extraKeys && cacheSalt == other.cacheSalt); } BlockKey BlockKey::shorten(int newNumTokens) const @@ -364,10 +364,10 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) no // Constants provide very good distribution - each input bit affects each output bit with ~50% probability. size_t seed = blockKey.uniqueTokens.size() ^ parentHash * UINT64_C(0xbf58476d1ce4e5b9); - if (parentHash == 0 && blockKey.cacheSaltID) + if (parentHash == 0 && blockKey.cacheSalt) { - // Only hashing the cache salt ID for the first block in the sequence - uint64_t c = blockKey.cacheSaltID.value(); + // Only mix the cache salt into the hash for the first block in the sequence. + uint64_t c = static_cast(std::hash{}(blockKey.cacheSalt.value())); seed = hash64Mix(c, seed); } diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp index c8f6ddd474f4..2a986adf310f 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp @@ -105,7 +105,8 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector const& blocks for (auto const& block : blocks) { data.blocks.emplace_back(block->getHash(), block->getUniqueTokens(), block->getBlockKey().loraTaskId, - block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority(), block->getExtraKeys()); + block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority(), block->getExtraKeys(), + block->getBlockKey().cacheSalt); } enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank}); diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 1a74166e5235..d1112439686e 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -2076,7 +2076,7 @@ std::shared_ptr WindowBlockManager::findBlocksInReuseTreeByBlockKe for (auto const& blockedUniqueTokensList : blockedUniqueTokens) { blockKeys.emplace_back(blockKey.usesExtraIds, blockKey.loraTaskId, blockedUniqueTokensList, blockKey.extraKeys, - blockKey.cacheSaltID); + blockKey.cacheSalt); } return searchReuseTree(blockKeys); } @@ -4460,7 +4460,7 @@ std::vector KVCacheManager::commitAndGetBlockHashesForRequest( bool const usesExtraIds = llmRequest.getInputTokensExtraIds().has_value(); auto const loraTaskId = llmRequest.getLoraTaskId(); - auto const cacheSaltID = llmRequest.getCacheSaltID(); + auto const cacheSalt = llmRequest.getCacheSalt(); std::vector hashes; hashes.reserve(static_cast(limit)); @@ -4476,7 +4476,7 @@ std::vector KVCacheManager::commitAndGetBlockHashesForRequest( SizeType32 const tokenEnd = tokenStart + tokensPerBlock; auto extraKeys = generateBlockHashExtraKeys(llmRequest, tokenStart, tokenEnd); VecUniqueTokens blockTokens(uniqueTokens.begin() + tokenStart, uniqueTokens.begin() + tokenEnd); - BlockKey blockKey(usesExtraIds, loraTaskId, std::move(blockTokens), std::move(extraKeys), cacheSaltID); + BlockKey blockKey(usesExtraIds, loraTaskId, std::move(blockTokens), std::move(extraKeys), cacheSalt); block->setBlockKey(blockKey, /*isFull=*/true); // setHash() chains through mPrevBlockInSeq, which was wired in addBlockToBeam. The // loop walks blocks in allocation order, so by the time we reach block b its diff --git a/cpp/tensorrt_llm/executor/request.cpp b/cpp/tensorrt_llm/executor/request.cpp index 5ac62d3fcb64..e32045892ba7 100644 --- a/cpp/tensorrt_llm/executor/request.cpp +++ b/cpp/tensorrt_llm/executor/request.cpp @@ -40,8 +40,8 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming, std::optional encoderOutputLength, std::optional crossAttentionMask, SizeType32 numReturnSequences, std::optional eagleConfig, std::optional skipCrossAttnBlocks, std::optional guidedDecodingParams, std::optional languageAdapterUid, - std::optional allottedTimeMs, std::optional cacheSaltID, - std::optional disaggRequestId) + std::optional allottedTimeMs, std::optional disaggRequestId, + std::optional cacheSalt) : mImpl(std::make_unique(std::move(inputTokenIds), maxTokens, streaming, samplingConfig, outputConfig, endId, padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias), std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput), @@ -50,7 +50,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming, std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, type, std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, crossAttentionMask, numReturnSequences, eagleConfig, skipCrossAttnBlocks, std::move(guidedDecodingParams), languageAdapterUid, - allottedTimeMs, cacheSaltID, disaggRequestId)) + allottedTimeMs, disaggRequestId, std::move(cacheSalt))) { } @@ -249,9 +249,9 @@ std::optional Request::getLanguageAdapterUid() const return mImpl->getLanguageAdapterUid(); } -std::optional Request::getCacheSaltID() const +std::optional Request::getCacheSalt() const { - return mImpl->getCacheSaltID(); + return mImpl->getCacheSalt(); } std::optional Request::getDisaggRequestId() const @@ -424,9 +424,9 @@ void Request::setLanguageAdapterUid(SizeType32 languageAdapterUid) mImpl->setLanguageAdapterUid(languageAdapterUid); } -void Request::setCacheSaltID(CacheSaltIDType cacheSaltID) +void Request::setCacheSalt(std::optional cacheSalt) { - mImpl->setCacheSaltID(cacheSaltID); + mImpl->setCacheSalt(std::move(cacheSalt)); } void Request::setDisaggRequestId(IdType disaggRequestId) diff --git a/cpp/tensorrt_llm/executor/requestImpl.h b/cpp/tensorrt_llm/executor/requestImpl.h index 281f81d462a7..55610885b1ae 100644 --- a/cpp/tensorrt_llm/executor/requestImpl.h +++ b/cpp/tensorrt_llm/executor/requestImpl.h @@ -32,6 +32,21 @@ class Request::Impl { public: + //! Maximum allowed length of a cache salt string. Cache salts are copied into every BlockKey and emitted + //! with KV cache events, so unbounded strings would inflate memory and serialization cost proportional to + //! the number of blocks. + static constexpr std::size_t kMaxCacheSaltLength{256}; + + static std::optional validateCacheSalt(std::optional cacheSalt) + { + if (cacheSalt.has_value() && cacheSalt->size() > kMaxCacheSaltLength) + { + TLLM_THROW("cacheSalt length (%zu) exceeds the maximum supported length (%zu).", cacheSalt->size(), + kMaxCacheSaltLength); + } + return cacheSalt; + } + Impl(VecTokens inputTokenIds, SizeType32 maxNewTokens, bool streaming, SamplingConfig const& samplingConfig, OutputConfig outputConfig, std::optional const& endId, std::optional const& padId, std::optional> positionIds, std::optional> badWords, @@ -48,7 +63,7 @@ class Request::Impl std::optional crossAttentionMask, SizeType32 numReturnSequences, std::optional eagleConfig, std::optional skipCrossAttnBlocks, std::optional guidedDecodingParams, std::optional languageAdapterUid, std::optional allottedTimeMs, - std::optional cacheSaltID, std::optional disaggRequestId) + std::optional disaggRequestId, std::optional cacheSalt = std::nullopt) : mInputTokenIds(std::move(inputTokenIds)) , mMaxNewTokens(maxNewTokens) , mStreaming(streaming) @@ -85,7 +100,7 @@ class Request::Impl , mGuidedDecodingParams(std::move(guidedDecodingParams)) , mLanguageAdapterUid(languageAdapterUid) , mAllottedTimeMs(allottedTimeMs) - , mCacheSaltID(cacheSaltID) + , mCacheSalt(validateCacheSalt(std::move(cacheSalt))) , mDisaggRequestId(disaggRequestId) { validate(); @@ -298,9 +313,9 @@ class Request::Impl return mLanguageAdapterUid; } - [[nodiscard]] std::optional getCacheSaltID() const + [[nodiscard]] std::optional getCacheSalt() const { - return mCacheSaltID; + return mCacheSalt; } [[nodiscard]] std::optional getDisaggRequestId() const @@ -482,9 +497,9 @@ class Request::Impl mLanguageAdapterUid = languageAdapterUid; } - void setCacheSaltID(CacheSaltIDType cacheSaltID) + void setCacheSalt(std::optional cacheSalt) { - mCacheSaltID = cacheSaltID; + mCacheSalt = validateCacheSalt(std::move(cacheSalt)); } void setDisaggRequestId(IdType disaggRequestId) @@ -565,8 +580,8 @@ class Request::Impl lambda(mGuidedDecodingParams); lambda(mLanguageAdapterUid); lambda(mAllottedTimeMs ? std::make_optional(mAllottedTimeMs->count()) : std::nullopt); - lambda(mCacheSaltID); lambda(mDisaggRequestId); + lambda(mCacheSalt); } VecTokens mInputTokenIds; @@ -605,7 +620,7 @@ class Request::Impl std::optional mGuidedDecodingParams; std::optional mLanguageAdapterUid; std::optional mAllottedTimeMs; - std::optional mCacheSaltID; + std::optional mCacheSalt; std::optional mDisaggRequestId; }; diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index e4c325423126..ce081e10c603 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -885,8 +885,8 @@ Request Serialization::deserializeRequest(std::istream& is) auto allottedTimeMs = allottedTimeInt ? std::optional(std::chrono::milliseconds(*allottedTimeInt)) : std::nullopt; - auto cacheSaltID = su::deserialize>(is); auto disaggRequestId = su::deserialize>(is); + auto cacheSalt = su::deserialize>(is); return Request(std::move(inputTokenIds), maxNewTokens, streaming, samplingConfig, outputConfig, endId, padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias), @@ -896,7 +896,7 @@ Request Serialization::deserializeRequest(std::istream& is) std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, requestType, std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, std::move(crossAttentionMask), numReturnSequences, std::move(eagleConfig), std::move(skipCrossAttnBlocks), - std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, cacheSaltID, disaggRequestId); + std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, disaggRequestId, std::move(cacheSalt)); } void Serialization::serialize(Request const& request, std::ostream& os) @@ -2517,6 +2517,7 @@ size_t Serialization::serializedSize(KVCacheStoredBlockData const& data) totalSize += su::serializedSize(data.cacheLevel); totalSize += su::serializedSize(data.priority); totalSize += su::serializedSize(data.mmKeys); + totalSize += su::serializedSize(data.cacheSalt); return totalSize; } @@ -2528,6 +2529,7 @@ void Serialization::serialize(KVCacheStoredBlockData const& data, std::ostream& su::serialize(data.cacheLevel, os); su::serialize(data.priority, os); su::serialize(data.mmKeys, os); + su::serialize(data.cacheSalt, os); } KVCacheStoredBlockData Serialization::deserializeKVCacheStoredBlockData(std::istream& is) @@ -2538,8 +2540,9 @@ KVCacheStoredBlockData Serialization::deserializeKVCacheStoredBlockData(std::ist auto cacheLevel = su::deserialize(is); auto priority = su::deserialize(is); auto mmKeys = su::deserialize>(is); + auto cacheSalt = su::deserialize>(is); - return KVCacheStoredBlockData{blockHash, tokens, loraId, cacheLevel, priority, mmKeys}; + return KVCacheStoredBlockData{blockHash, tokens, loraId, cacheLevel, priority, mmKeys, cacheSalt}; } // KVcacheRemovedData @@ -2686,7 +2689,7 @@ size_t Serialization::serializedSize(tensorrt_llm::batch_manager::kv_cache_manag totalSize += su::serializedSize(key.uniqueTokens); // std::vector where MmKey is pair, SizeType32> totalSize += su::serializedSize(key.extraKeys); - totalSize += su::serializedSize(key.cacheSaltID); + totalSize += su::serializedSize(key.cacheSalt); return totalSize; } @@ -2696,7 +2699,7 @@ void Serialization::serialize(tensorrt_llm::batch_manager::kv_cache_manager::Blo su::serialize(key.loraTaskId, os); su::serialize(key.uniqueTokens, os); su::serialize(key.extraKeys, os); - su::serialize(key.cacheSaltID, os); + su::serialize(key.cacheSalt, os); } tensorrt_llm::batch_manager::kv_cache_manager::BlockKey Serialization::deserializeBlockKey(std::istream& is) @@ -2705,13 +2708,13 @@ tensorrt_llm::batch_manager::kv_cache_manager::BlockKey Serialization::deseriali auto loraTaskId = su::deserialize>(is); auto uniqueTokens = su::deserialize>(is); auto extraKeys = su::deserialize>(is); - auto cacheSaltID = su::deserialize>(is); + auto cacheSalt = su::deserialize>(is); tensorrt_llm::batch_manager::kv_cache_manager::BlockKey key; key.usesExtraIds = usesExtraIds; key.loraTaskId = std::move(loraTaskId); key.uniqueTokens = std::move(uniqueTokens); key.extraKeys = std::move(extraKeys); - key.cacheSaltID = std::move(cacheSaltID); + key.cacheSalt = std::move(cacheSalt); return key; } diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index 5447f13c60e2..086cdf7547c2 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -202,7 +202,7 @@ void initBindings(nb::module_& m) .def_prop_ro("llm_request_type", &GenLlmReq::getLlmRequestType) .def_prop_ro("parent_request_id", &GenLlmReq::getParentRequestId) .def_prop_ro("is_child", &GenLlmReq::isChild) - .def_prop_ro("cache_salt_id", &GenLlmReq::getCacheSaltID) + .def_prop_ro("cache_salt", &GenLlmReq::getCacheSalt) .def_prop_ro("kv_cache_retention_config", &GenLlmReq::getKvCacheRetentionConfig) .def_prop_ro("multimodal_hashes", [](GenLlmReq& self) @@ -347,12 +347,12 @@ void initBindings(nb::module_& m) std::optional language_adapter_uid, std::optional allotted_time_ms, std::optional context_phase_params, - std::optional cache_salt_id, std::optional arrival_time, std::optional>> agent_hierarchy, std::optional> multimodal_item_run_cu_offsets, std::optional> multimodal_run_positions, - std::optional> multimodal_run_lengths) + std::optional> multimodal_run_lengths, + std::optional cache_salt) { auto makeOptionalTensor = [](std::optional const& atTensor, bool unsqueeze = false) { @@ -392,9 +392,9 @@ void initBindings(nb::module_& m) encoder_input_tokens, return_encoder_output, client_id, priority, encoder_input_features_tensor_ptr, encoder_output_length, cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config, skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, - guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, - arrival_time, std::move(agent_hierarchy), multimodal_item_run_cu_offsets, multimodal_run_positions, - multimodal_run_lengths}; + guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params, arrival_time, + std::move(agent_hierarchy), multimodal_item_run_cu_offsets, multimodal_run_positions, + multimodal_run_lengths, std::move(cache_salt)}; }, nb::arg("request_id"), nb::arg("max_new_tokens"), nb::arg("input_tokens"), nb::arg("sampling_config"), nb::arg("is_streaming"), nb::arg("end_id") = std::nullopt, nb::arg("pad_id") = std::nullopt, @@ -420,10 +420,10 @@ void initBindings(nb::module_& m) nb::arg("eagle_config") = std::nullopt, nb::arg("skip_cross_attn_blocks") = std::nullopt, nb::arg("return_perf_metrics") = false, nb::arg("guided_decoding_params") = std::nullopt, nb::arg("language_adapter_uid") = std::nullopt, nb::arg("allotted_time_ms") = std::nullopt, - nb::arg("context_phase_params") = std::nullopt, nb::arg("cache_salt_id") = std::nullopt, - nb::arg("arrival_time") = std::nullopt, nb::arg("agent_hierarchy") = std::nullopt, - nb::arg("multimodal_item_run_cu_offsets") = std::nullopt, - nb::arg("multimodal_run_positions") = std::nullopt, nb::arg("multimodal_run_lengths") = std::nullopt) + nb::arg("context_phase_params") = std::nullopt, nb::arg("arrival_time") = std::nullopt, + nb::arg("agent_hierarchy") = std::nullopt, nb::arg("multimodal_item_run_cu_offsets") = std::nullopt, + nb::arg("multimodal_run_positions") = std::nullopt, nb::arg("multimodal_run_lengths") = std::nullopt, + nb::arg("cache_salt") = std::nullopt) .def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, nb::arg("vocab_size")) .def(nb::init()) .def("validate", &tb::LlmRequest::validate, nb::arg("max_input_len"), nb::arg("max_seq_len"), diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp index 796909bd419a..21f9ea39823b 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp @@ -127,11 +127,11 @@ std::shared_ptr LlmRequest::toTrtLlm() const mLanguageAdapterUid, // mAllottedTimeMs, // mContextPhaseParams, // - mCacheSaltID, // mPerfMetrics.timingMetrics.arrivalTime, // mAgentHierarchy, // mMultimodalItemRunCuOffsets, // mMultimodalRunPositions, // - mMultimodalRunLengths // + mMultimodalRunLengths, // + mCacheSalt // ); } diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h index 967870c8177c..387d915ab4aa 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h @@ -86,7 +86,7 @@ class LlmRequest : public tb::GenericLlmRequest std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt, + std::optional arrivalTime = std::nullopt, std::optional>> agent_hierarchy = std::nullopt, std::optional> multimodalItemRunCuOffsets = std::nullopt, std::optional> multimodalRunPositions = std::nullopt, @@ -155,7 +155,6 @@ class LlmRequest : public tb::GenericLlmRequest languageAdapterUid, // allottedTimeMs, // contextPhaseParams, // - cacheSaltID, // arrivalTime, // std::move(agent_hierarchy), // multimodalItemRunCuOffsets.has_value() diff --git a/cpp/tensorrt_llm/nanobind/executor/bindings.cpp b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp index fbec513de3a1..b0ad31b7347e 100644 --- a/cpp/tensorrt_llm/nanobind/executor/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp @@ -230,6 +230,7 @@ void initBindings(nb::module_& m) .def_ro("lora_id", &tle::KVCacheStoredBlockData::loraId) .def_ro("cache_level", &tle::KVCacheStoredBlockData::cacheLevel) .def_ro("priority", &tle::KVCacheStoredBlockData::priority) + .def_ro("cache_salt", &tle::KVCacheStoredBlockData::cacheSalt) .def_prop_ro("mm_keys", [](tle::KVCacheStoredBlockData const& self) { diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp index 6cffe7740c13..b502370504d1 100644 --- a/cpp/tensorrt_llm/nanobind/executor/request.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -613,7 +613,7 @@ void initRequestBindings(nb::module_& m) self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(), self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(), self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(), - self.getGuidedDecodingParams(), self.getCacheSaltID(), self.getDisaggRequestId()); + self.getGuidedDecodingParams(), self.getDisaggRequestId(), self.getCacheSalt()); }; auto requestSetstate = [](tle::Request& self, nb::tuple const& state) { @@ -642,7 +642,7 @@ void initRequestBindings(nb::module_& m) nb::cast>(state[29]), 1, nb::cast>(state[30]), nb::cast>(state[31]), nb::cast>(state[32]), std::nullopt, std::nullopt, - nb::cast>(state[33]), nb::cast>(state[34])); + nb::cast>(state[33]), nb::cast>(state[34])); }; nb::class_ request(m, "Request", nb::dynamic_attr()); @@ -683,8 +683,8 @@ void initRequestBindings(nb::module_& m) std::optional, // guidedDecodingParams std::optional, // languageAdapterUid std::optional, // allottedTimeMs - std::optional, // cacheSaltID - std::optional // disaggRequestId + std::optional, // disaggRequestId + std::optional // cacheSalt >(), // clang-format off nb::arg("input_token_ids"), @@ -724,9 +724,9 @@ void initRequestBindings(nb::module_& m) nb::arg("guided_decoding_params") = nb::none(), nb::arg("language_adapter_uid") = nb::none(), nb::arg("allotted_time_ms") = nb::none(), - nb::arg("cache_salt_id") = nb::none(), - nb::arg("disagg_request_id") = nb::none() - ) // clang-format on + nb::arg("disagg_request_id") = nb::none(), + nb::arg("cache_salt") = nb::none() + ) // clang-format on .def_prop_ro("input_token_ids", &tle::Request::getInputTokenIds) .def_prop_ro("max_tokens", &tle::Request::getMaxTokens) .def_prop_rw("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming) @@ -768,7 +768,7 @@ void initRequestBindings(nb::module_& m) .def_prop_rw( "guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams) .def_prop_rw("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs) - .def_prop_rw("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID) + .def_prop_rw("cache_salt", &tle::Request::getCacheSalt, &tle::Request::setCacheSalt) .def_prop_rw("context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams) .def_prop_rw("disagg_request_id", &tle::Request::getDisaggRequestId, &tle::Request::setDisaggRequestId) .def_prop_rw("priority", &tle::Request::getPriority, &tle::Request::setPriority) diff --git a/cpp/tests/unit_tests/batch_manager/agentTreeTest.cpp b/cpp/tests/unit_tests/batch_manager/agentTreeTest.cpp index 32e0b3e0ea17..5104f2a88a71 100644 --- a/cpp/tests/unit_tests/batch_manager/agentTreeTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/agentTreeTest.cpp @@ -64,7 +64,7 @@ class AgentTreeTest : public ::testing::Test std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, tensorrt_llm::executor::Request::kDefaultPriority, std::nullopt, std::nullopt, std::nullopt, tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, 1, std::nullopt, std::nullopt, - false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, agentHierarchy); + false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, agentHierarchy); } LlmRequestPtr createAgentDeepResearchRequest(SizeType32 nodeId, SizeType32 requestId) diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index bebc6dd920ac..bf1d7589ce31 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -2094,12 +2094,11 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); } -TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) +TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltTest) { - // Test that cache_salt_id prevents KV cache reuse between requests with same tokens - // but different cache_salt_id values. + // Test that cache_salt prevents KV cache reuse between requests with same tokens + // but different cache_salt values. using VecTokenExtraIds = LlmRequest::VecTokenExtraIds; - using CacheSaltIDType = LlmRequest::CacheSaltIDType; auto constexpr numLayers = 12; auto constexpr numKvHeads = 6; @@ -2135,7 +2134,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) auto const inputLength = static_cast(inputTokens->size()); /////////////////////////////////////////////////////////////////////////// - // Test Case 1: Request without cache_salt_id + // Test Case 1: Request without cache_salt LlmRequest::RequestIdType requestId{0}; auto llmRequest0 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, @@ -2143,8 +2142,8 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, - std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt); // No cache_salt_id + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt); // No cache_salt GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; @@ -2177,21 +2176,22 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); /////////////////////////////////////////////////////////////////////////// - // Test Case 2: Request with same tokens but with cache_salt_id = 12345 + // Test Case 2: Request with same tokens but with cache_salt = "tenant-A" requestId = 1; - CacheSaltIDType cacheSaltId1{12345}; + std::string const cacheSalt1{"tenant-A"}; auto llmRequest1 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, - std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - cacheSaltId1); // With cache_salt_id = 12345 + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, + cacheSalt1); // With cache_salt = "tenant-A" GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; - // Should NOT reuse blocks despite same tokens, because cache_salt_id is different + // Should NOT reuse blocks despite same tokens, because cache_salt is different auto promptLen1 = llmRequest1->getNumTokens(beamIdx); auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); auto prepopulatedPromptLen1 = blockManager @@ -2215,7 +2215,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); /////////////////////////////////////////////////////////////////////////// - // Test Case 3: Request with same tokens and same cache_salt_id = 12345 + // Test Case 3: Request with same tokens and same cache_salt = "tenant-A" requestId = 2; auto llmRequest2 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, @@ -2223,12 +2223,13 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, - std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - cacheSaltId1); // Same cache_salt_id = 12345 + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, + cacheSalt1); // Same cache_salt = "tenant-A" GenerationRequest seq2{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; - // SHOULD reuse blocks because both tokens and cache_salt_id match + // SHOULD reuse blocks because both tokens and cache_salt match auto promptLen2 = llmRequest2->getNumTokens(beamIdx); auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); auto prepopulatedPromptLen2 = blockManager @@ -2252,21 +2253,22 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); /////////////////////////////////////////////////////////////////////////// - // Test Case 4: Request with same tokens but different cache_salt_id = 67890 + // Test Case 4: Request with same tokens but different cache_salt = "tenant-B" requestId = 3; - CacheSaltIDType cacheSaltId2{67890}; + std::string const cacheSalt2{"tenant-B"}; auto llmRequest3 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, - std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - cacheSaltId2); // Different cache_salt_id = 67890 + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, + cacheSalt2); // Different cache_salt = "tenant-B" GenerationRequest seq3{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; - // Should NOT reuse blocks from any previous request because cache_salt_id is different + // Should NOT reuse blocks from any previous request because cache_salt is different auto promptLen3 = llmRequest3->getNumTokens(beamIdx); auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); auto prepopulatedPromptLen3 = blockManager @@ -2284,7 +2286,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); /////////////////////////////////////////////////////////////////////////// - // Test Case 5: Request without cache_salt_id again + // Test Case 5: Request without cache_salt again requestId = 4; auto llmRequest4 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, @@ -2292,12 +2294,12 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, - std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, - std::nullopt); // No cache_salt_id + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt); // No cache_salt GenerationRequest seq4{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; - // Should reuse blocks from request0 (blocks 0,1) because both have no cache_salt_id + // Should reuse blocks from request0 (blocks 0,1) because both have no cache_salt auto promptLen4 = llmRequest4->getNumTokens(beamIdx); auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); auto prepopulatedPromptLen4 = blockManager diff --git a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp index d80d0be456b4..98569bfc2aa0 100644 --- a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp +++ b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp @@ -1097,7 +1097,7 @@ TEST(SerializeUtilsTest, BlockKeyWithExtras) VecUniqueTokens uniqueTokens{UniqueToken{10, 100}, UniqueToken{20, 200}}; std::optional loraTaskId = LoraTaskIdType{42}; - // Note: cacheSaltID is intentionally not set since it is not serialized + // Note: cacheSalt is intentionally not set; round-tripping with it set is covered separately. BlockKey key(true, loraTaskId, uniqueTokens, extraKeys); testSerializeDeserialize(key); diff --git a/examples/cpp/executor/executorExampleKvEvents.cpp b/examples/cpp/executor/executorExampleKvEvents.cpp index a48cbdfa9769..ea1923294382 100644 --- a/examples/cpp/executor/executorExampleKvEvents.cpp +++ b/examples/cpp/executor/executorExampleKvEvents.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -50,13 +50,14 @@ struct RuntimeOptions struct KVCacheBlock { KVCacheBlock(size_t hash, int cacheLevel, int priority, std::optional loraId = std::nullopt, - std::shared_ptr prevBlock = nullptr); + std::shared_ptr prevBlock = nullptr, std::optional cacheSalt = std::nullopt); size_t hash; int cacheLevel; int priority; std::optional loraId; + std::optional cacheSalt; std::shared_ptr prevBlock; std::unordered_map> nextBlocks; @@ -196,12 +197,13 @@ RuntimeOptions parseArgs(int argc, char* argv[]) return runtimeOpts; } -KVCacheBlock::KVCacheBlock( - size_t hash, int cacheLevel, int priority, std::optional loraId, std::shared_ptr prevBlock) +KVCacheBlock::KVCacheBlock(size_t hash, int cacheLevel, int priority, std::optional loraId, + std::shared_ptr prevBlock, std::optional cacheSalt) : hash{hash} , cacheLevel{cacheLevel} , priority{priority} , loraId{loraId} + , cacheSalt{std::move(cacheSalt)} , prevBlock{prevBlock} , nextBlocks{} { @@ -255,7 +257,7 @@ void RadixTree::pollEvents() TLLM_CHECK(block.tokens.size() > 0); auto thisBlock = std::make_shared( - block.blockHash, block.cacheLevel, block.priority, block.loraId, prevBlock); + block.blockHash, block.cacheLevel, block.priority, block.loraId, prevBlock, block.cacheSalt); blockTable[block.blockHash] = thisBlock; // Link the parent to the new block diff --git a/examples/llm-api/llm_kv_cache_connector.py b/examples/llm-api/llm_kv_cache_connector.py index 0e6aa3d83aa9..882478993e5a 100644 --- a/examples/llm-api/llm_kv_cache_connector.py +++ b/examples/llm-api/llm_kv_cache_connector.py @@ -192,7 +192,7 @@ def build_connector_meta(self, scheduler_output: SchedulerOutput): len(block_ids)): if len(chunks[block_pos]) == self.block_size: hashed_tokens = self._hash_tokens(chunks[block_pos], - req.cache_salt_id) + req.cache_salt) file_path = self._file_path(hashed_tokens) @@ -202,11 +202,10 @@ def build_connector_meta(self, scheduler_output: SchedulerOutput): return metadata - def _hash_tokens(self, tokens: list[int], - cache_salt_id: Optional[int]) -> int: - # cache_salt_id must participate in the hash so that requests carrying + def _hash_tokens(self, tokens: list[int], cache_salt: Optional[str]) -> int: + # cache_salt must participate in the hash so that requests carrying # different salts (or no salt) cannot collide on the same cache file. - return abs(hash((cache_salt_id, tuple(tokens)))) + return abs(hash((cache_salt, tuple(tokens)))) def _file_path(self, hash_value: int) -> Path: return Path(self.cache_folder) / f"{hash_value}.pt" @@ -238,7 +237,7 @@ def get_num_new_matched_tokens( for chunk in remaining_chunks: # Only do full blocks. if len(chunk) == self.block_size: - hashed_tokens = self._hash_tokens(chunk, request.cache_salt_id) + hashed_tokens = self._hash_tokens(chunk, request.cache_salt) file_path = self._file_path(hashed_tokens) diff --git a/tensorrt_llm/_torch/pyexecutor/connectors/kv_cache_connector.py b/tensorrt_llm/_torch/pyexecutor/connectors/kv_cache_connector.py index f5034256e142..99da2a42265c 100644 --- a/tensorrt_llm/_torch/pyexecutor/connectors/kv_cache_connector.py +++ b/tensorrt_llm/_torch/pyexecutor/connectors/kv_cache_connector.py @@ -82,9 +82,9 @@ class RequestData: # Per-request cache salt that the KV cache manager uses to isolate reuse # between requests carrying different salts. Connectors that key cached # content on token sequences (e.g. by hashing tokens to a file path or - # remote object id) MUST mix cache_salt_id into their identifiers, + # remote object id) MUST mix cache_salt into their identifiers, # otherwise blocks from a different salt could be incorrectly reused. - cache_salt_id: Optional[int] = None + cache_salt: Optional[str] = None # A class to store some basic data regarding all inflight requests. @@ -361,7 +361,7 @@ def update_and_build_data(self, req: LlmRequest, kv_cache_manager: "KVCacheManag num_scheduled_tokens, block_hashes=block_hashes, priorities=priorities, - cache_salt_id=req.cache_salt_id, + cache_salt=req.cache_salt, ) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 21bdd6c0b3e4..a7d6614d6f4f 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -1099,7 +1099,7 @@ def executor_request_to_llm_request( priority=executor_request.priority, llm_request_type=llm_request_type, context_phase_params=executor_request.context_phase_params, - cache_salt_id=executor_request.cache_salt_id, + cache_salt=executor_request.cache_salt, arrival_time=getattr(executor_request, "py_arrival_time", None), py_multimodal_data=getattr(executor_request, "py_multimodal_data", None), diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index d8e948d15689..c67bec2f22b9 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -3,6 +3,7 @@ import copy import enum +import hashlib import math import os from abc import ABC, abstractmethod @@ -2996,11 +2997,10 @@ def prepare_context(self, req: LlmRequest) -> bool: all_tokens, req, end=len(all_tokens) - 1) else: tokens = None - kv_cache = self._create_kv_cache( - req.py_request_id, - req.lora_task_id, - tokens, - cache_salt_id=req.cache_salt_id) + kv_cache = self._create_kv_cache(req.py_request_id, + req.lora_task_id, + tokens, + cache_salt=req.cache_salt) if kv_cache is None: return False kv_cache.cuda_stream = self._stream.cuda_stream @@ -3123,11 +3123,10 @@ def _prepare_draft_resources(self, scheduled_batch: ScheduledRequests): for req in scheduled_batch.context_requests: kv_cache = self.kv_cache_map.get(req.py_request_id) if kv_cache is None: - kv_cache = self._create_kv_cache( - req.py_request_id, - req.lora_task_id, - None, - cache_salt_id=req.cache_salt_id) + kv_cache = self._create_kv_cache(req.py_request_id, + req.lora_task_id, + None, + cache_salt=req.cache_salt) kv_cache.stop_committing() if not self._resume_and_restore(req.py_request_id, kv_cache): raise RuntimeError( @@ -3295,7 +3294,7 @@ def release_resources(current_request: LlmRequest, if prepare_resource: # Dummy/warmup request. ``stop_committing()`` below blocks all # writes to the radix tree, so the choice of branch does not - # affect committed state. ``cache_salt_id`` is left defaulted + # affect committed state. ``cache_salt`` is left defaulted # to None to avoid coupling synthetic data to any salted branch. kv_cache = self._create_kv_cache(req.py_request_id, req.lora_task_id, input_tokens) @@ -3674,7 +3673,7 @@ def _create_kv_cache(self, request_id: int, lora_task_id: int | None, input_tokens: Sequence[TokenIdExt] | None, - cache_salt_id: int | None = None): + cache_salt: str | None = None): assert request_id not in self.kv_cache_map, f"KV cache for request {request_id} already exists" if self.index_mapper.num_free_slots() == 0: logger.warning( @@ -3683,8 +3682,14 @@ def _create_kv_cache(self, "Skipping KV cache creation; request will retry next iteration.", request_id, self.index_mapper.size(), self.index_mapper.size()) return None + # ReuseScope.salt is int|None; derive a deterministic int from the + # cache_salt string so the same string yields the same reuse namespace + # across processes (matches C++ blockKey hashing on cacheSalt). + salt_int = (int.from_bytes( + hashlib.sha256(cache_salt.encode("utf-8")).digest()[:8], "little") + if cache_salt is not None else None) kv_cache = self.impl.create_kv_cache( - ReuseScope(lora_id=lora_task_id, salt=cache_salt_id), + ReuseScope(lora_id=lora_task_id, salt=salt_int), input_tokens, ) self.kv_cache_map[request_id] = kv_cache diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index 841f396104f5..955a0ab55082 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -1212,6 +1212,8 @@ def _stored_block_to_json(data): for token in data.tokens ], # "lora_id": data.lora_id, # TODO (shreyasm): enable serialization of lora_id + "cache_salt": + data.cache_salt, "cache_level": data.cache_level, "priority": diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index e42f2b25b24c..090742b38d19 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -594,8 +594,8 @@ def _deduce_max_tokens(request: GenerationRequest, kv_cache_retention_config=request.kv_cache_retention_config, context_phase_params=context_phase_params, type=request_type, - cache_salt_id=request.cache_salt_id, disagg_request_id=disagg_request_id, + cache_salt=request.cache_salt, priority=request.priority) executor_request.py_original_end_id = request.sampling_params.end_id executor_request.py_num_logprobs = request.sampling_params.logprobs diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index a5f5efea6427..5ea904531f22 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -134,7 +134,7 @@ def generate_async( postproc_params: Optional[PostprocParams] = None, multimodal_params: Optional[MultimodalParams] = None, scheduling_params: Optional[SchedulingParams] = None, - cache_salt_id: Optional[int] = None, + cache_salt: Optional[str] = None, arrival_time: Optional[float] = None, priority: float = DEFAULT_REQUEST_PRIORITY, ) -> GenerationResult: @@ -162,7 +162,7 @@ def generate_async( trace_headers=trace_headers, multimodal_params=multimodal_params, scheduling_params=scheduling_params, - cache_salt_id=cache_salt_id, + cache_salt=cache_salt, arrival_time=arrival_time, priority=priority) result = self.submit(request) @@ -180,6 +180,7 @@ def generate( prompt_adapter_request: Optional[Union[ PromptAdapterRequest, List[PromptAdapterRequest]]] = None, disaggregated_params: Optional[DisaggregatedParams] = None, + cache_salt: Optional[Union[str, List[Optional[str]]]] = None, ) -> Union[GenerationResult, List[GenerationResult]]: """Generate output for the given prompt token ids in the synchronous mode. Synchronous generation accepts either single prompt or batched prompts. @@ -205,6 +206,7 @@ def generate( pa_req = prompt_adapter_request[i] else: pa_req = prompt_adapter_request + cs = cache_salt[i] if isinstance(cache_salt, list) else cache_salt future = self.generate_async( p, sampling_params=sp, @@ -212,7 +214,8 @@ def generate( lora_request=lora_req, prompt_adapter_request=pa_req, streaming=False, - disaggregated_params=disaggregated_params) + disaggregated_params=disaggregated_params, + cache_salt=cs) futures.append(future) for future in futures: diff --git a/tensorrt_llm/executor/request.py b/tensorrt_llm/executor/request.py index adbc2358b372..43ea11706e54 100644 --- a/tensorrt_llm/executor/request.py +++ b/tensorrt_llm/executor/request.py @@ -89,6 +89,8 @@ def local_path(self): class GenerationRequest: + # Mirrors C++ Request::Impl::kMaxCacheSaltLength + MAX_CACHE_SALT_LEN: int = 256 def __init__( self, @@ -105,7 +107,7 @@ def __init__( postproc_params: Optional[PostprocParams] = None, multimodal_params: Optional[MultimodalParams] = None, scheduling_params: Optional[SchedulingParams] = None, - cache_salt_id: Optional[int] = None, + cache_salt: Optional[str] = None, arrival_time: Optional[float] = None, priority: float = DEFAULT_REQUEST_PRIORITY, ): @@ -134,7 +136,21 @@ def __init__( self.disaggregated_params = disaggregated_params self.trace_headers = trace_headers self.scheduling_params = scheduling_params - self.cache_salt_id = cache_salt_id + if cache_salt is not None: + if not isinstance(cache_salt, str): + raise TypeError( + f"cache_salt must be str or None, got {type(cache_salt).__name__}" + ) + # The C++ side validates against UTF-8 byte length, so do the same here + # (Python `len()` would count Unicode code points, which can pass this + # guard but fail at C++ dispatch for non-ASCII salts). + cache_salt_byte_len = len(cache_salt.encode("utf-8")) + if cache_salt_byte_len > self.MAX_CACHE_SALT_LEN: + raise ValueError( + f"cache_salt UTF-8 byte length ({cache_salt_byte_len}) " + f"exceeds the maximum supported length " + f"({self.MAX_CACHE_SALT_LEN}).") + self.cache_salt = cache_salt self.arrival_time = arrival_time if not (0.0 <= priority <= 1.0): raise ValueError( diff --git a/tensorrt_llm/inputs/__init__.py b/tensorrt_llm/inputs/__init__.py index 4f2d4cc7e99b..3b9a5d51d052 100644 --- a/tensorrt_llm/inputs/__init__.py +++ b/tensorrt_llm/inputs/__init__.py @@ -24,8 +24,7 @@ async_load_audio, async_load_image, async_load_video, convert_image_mode, default_multimodal_input_loader, encode_base64_content_from_url, encode_base64_image, - get_cache_salt_id, load_base64_image_embeds, load_image, - load_video) + load_base64_image_embeds, load_image, load_video) # yapf: enable @@ -69,7 +68,6 @@ "encode_base64_image", "load_image", "load_video", - "get_cache_salt_id", "compute_retained_tokens_count", "compute_retained_tokens_from_tubelet_budget", "compute_retention_mask", diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index 59483e625790..eaa98cad44ef 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -27,8 +27,7 @@ _safe_aiohttp_get, _safe_request_get) from tensorrt_llm.inputs.media_io import \ convert_image_mode as convert_image_mode -from tensorrt_llm.inputs.multimodal import (MultimodalServerConfig, - default_hasher) +from tensorrt_llm.inputs.multimodal import MultimodalServerConfig from tensorrt_llm.inputs.multimodal_data import \ BaseModalityData as BaseModalityData from tensorrt_llm.inputs.multimodal_data import VideoData as VideoData @@ -894,14 +893,3 @@ def convert_to_conversation_message( inputs.append(input) return inputs - - -def get_cache_salt_id(cache_salt: str) -> int: - b = cache_salt.encode("utf-8") - h = default_hasher(b).digest(length=8) - cache_salt_id = int.from_bytes(h, "little", signed=False) - if cache_salt_id < 0 or cache_salt_id >= (1 << 64): - raise ValueError( - f"cache_salt_id must be in [0, 2**64 - 1], got {cache_salt_id}.") - - return cache_salt_id diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index d8c471d8624e..f99aa1593c9f 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -36,7 +36,7 @@ from ..executor.utils import (RequestError, create_mpi_comm_session, get_spawn_proxy_process_env) from ..inputs import (PromptInputs, create_input_processor, - create_input_processor_with_hash, get_cache_salt_id, + create_input_processor_with_hash, maybe_compute_mm_embed_cumsum, prompt_inputs) from ..logger import logger from ..sampling_params import SamplingParams @@ -476,8 +476,6 @@ def generate_async( sampling_params = self._prepare_sampling_params(sampling_params) - cache_salt_id = get_cache_salt_id( - cache_salt) if cache_salt is not None else None # With pytorch backend, py_executor has logic to handle max_tokens of 1, # so set to 1 to avoid allocating unnecessary KV cache blocks for single request # TODO: Also support for trt backend @@ -520,7 +518,7 @@ def generate_async( postproc_params=_postproc_params, multimodal_params=multimodal_params, scheduling_params=scheduling_params, - cache_salt_id=cache_salt_id, + cache_salt=cache_salt, arrival_time=arrival_time, priority=priority, ) diff --git a/tests/unittest/llmapi/test_llm_kv_cache_events.py b/tests/unittest/llmapi/test_llm_kv_cache_events.py index 5831c5b7a57a..ee002905d26c 100644 --- a/tests/unittest/llmapi/test_llm_kv_cache_events.py +++ b/tests/unittest/llmapi/test_llm_kv_cache_events.py @@ -109,6 +109,9 @@ def test_kv_cache_event_data_serialization(): # Verify mm_keys field exists (empty for text-only requests) assert "mm_keys" in serialized_event[0]["data"]["blocks"][0] assert serialized_event[0]["data"]["blocks"][0]["mm_keys"] == [] + # Verify cache_salt field exists (None for unsalted requests) + assert "cache_salt" in serialized_event[0]["data"]["blocks"][0] + assert serialized_event[0]["data"]["blocks"][0]["cache_salt"] is None req2 = create_llm_request(1, [1, 2, 3, 4, 5]) kv_cache_manager.impl.add_sequence_batch( @@ -779,7 +782,7 @@ def test_mm_keys_in_stored_events(): events = llm.get_kv_cache_events(5) - # Find stored events and verify mm_keys field + # Find stored events and verify mm_keys and cache_salt fields for event in events: if event and event["data"]["type"] == "stored": blocks = event["data"]["blocks"] @@ -789,6 +792,86 @@ def test_mm_keys_in_stored_events(): assert isinstance(block["mm_keys"], list) # For text-only requests, mm_keys should be empty assert block["mm_keys"] == [] + # cache_salt should be present (None for unsalted requests) + assert "cache_salt" in block + assert block["cache_salt"] is None + + +def test_cache_salt_in_stored_events(): + """Test that cache_salt string is preserved in stored block events.""" + llm = create_llm() + sampling_params = SamplingParams(max_tokens=6, temperature=0.01) + prompt = "Hello, my name is" + + _ = llm.generate(prompt, + sampling_params=sampling_params, + cache_salt="tenant-A") + + events = llm.get_kv_cache_events(5) + + # Find stored events and verify cache_salt field + found_stored = False + for event in events: + if event and event["data"]["type"] == "stored": + found_stored = True + blocks = event["data"]["blocks"] + for block in blocks: + assert "cache_salt" in block + assert block["cache_salt"] == "tenant-A" + + assert found_stored, "No stored events found" + + +def test_cache_salt_max_length_validation(): + """cache_salt longer than MAX_CACHE_SALT_LEN UTF-8 bytes is rejected.""" + from tensorrt_llm.executor.request import GenerationRequest + + max_len = GenerationRequest.MAX_CACHE_SALT_LEN + sampling_params = SamplingParams() + + # ASCII salt at the limit is accepted. + GenerationRequest(prompt_token_ids=[1, 2, 3], + sampling_params=sampling_params, + cache_salt="a" * max_len) + + # ASCII salt one byte over the limit is rejected. + with pytest.raises(ValueError, match="cache_salt UTF-8 byte length"): + GenerationRequest(prompt_token_ids=[1, 2, 3], + sampling_params=sampling_params, + cache_salt="a" * (max_len + 1)) + + # Non-ASCII salt: each character is 3 UTF-8 bytes. A salt whose + # `len()` is below the limit but whose UTF-8 byte count exceeds it + # must be rejected (this is the case Python's len()-based check missed). + char_count = (max_len // 3) + 1 # len() is well below max_len + salt = "中" * char_count # Chinese character, 3 UTF-8 bytes each + assert len(salt) <= max_len + assert len(salt.encode("utf-8")) > max_len + with pytest.raises(ValueError, match="cache_salt UTF-8 byte length"): + GenerationRequest(prompt_token_ids=[1, 2, 3], + sampling_params=sampling_params, + cache_salt=salt) + + +def test_non_ascii_cache_salt_in_stored_events(): + """Test that a non-ASCII cache_salt string is preserved in stored block events.""" + llm = create_llm() + sampling_params = SamplingParams(max_tokens=6, temperature=0.01) + prompt = "Hello, my name is" + salt = "tenant-中文" # mixed ASCII + Chinese + + _ = llm.generate(prompt, sampling_params=sampling_params, cache_salt=salt) + + events = llm.get_kv_cache_events(5) + + found_stored = False + for event in events: + if event and event["data"]["type"] == "stored": + found_stored = True + for block in event["data"]["blocks"]: + assert block.get("cache_salt") == salt + + assert found_stored, "No stored events found" def test_expected_kv_cache_events():