diff --git a/.github/actions/setup-tokenspeed/action.yml b/.github/actions/setup-tokenspeed/action.yml new file mode 100644 index 000000000..547ea35fd --- /dev/null +++ b/.github/actions/setup-tokenspeed/action.yml @@ -0,0 +1,13 @@ +name: 'Setup TokenSpeed Backend' +description: 'Create Python venv and install TokenSpeed (engine + kernel + scheduler) from source.' + +runs: + using: 'composite' + steps: + - name: Setup Python venv + shell: bash + run: bash scripts/ci_setup_python_venv.sh + + - name: Install TokenSpeed + shell: bash + run: bash scripts/ci_install_tokenspeed.sh diff --git a/.github/workflows/e2e-gpu-job.yml b/.github/workflows/e2e-gpu-job.yml index bc05403ee..29b693899 100644 --- a/.github/workflows/e2e-gpu-job.yml +++ b/.github/workflows/e2e-gpu-job.yml @@ -6,7 +6,7 @@ on: engine: required: true type: string - description: "Engine to test: sglang, vllm, or trtllm" + description: "Engine to test: sglang, vllm, trtllm, or tokenspeed" gpu_tier: required: true type: string @@ -68,6 +68,10 @@ jobs: if: inputs.engine == 'trtllm' uses: ./.github/actions/setup-trtllm + - name: Setup TokenSpeed backend + if: inputs.engine == 'tokenspeed' + uses: ./.github/actions/setup-tokenspeed + # Artifact downloads - name: Download wheel artifact uses: actions/download-artifact@v8 diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index 29621368b..e8ff049b8 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -390,6 +390,7 @@ jobs: - 'scripts/ci_setup_python_venv.sh' - 'scripts/ci_install_sglang.sh' - 'scripts/ci_install_vllm.sh' + - 'scripts/ci_install_tokenspeed.sh' - 'scripts/ci_install_e2e_deps.sh' - 'scripts/ci_killall_sglang.sh' - 'scripts/ci_build_wheel.sh' @@ -404,6 +405,7 @@ jobs: - 'e2e_test/router/**' - 'scripts/ci_install_vllm.sh' - 'scripts/ci_install_trtllm.sh' + - 'scripts/ci_install_tokenspeed.sh' agentic: - 'crates/mcp/**' - 'crates/data_connector/**' @@ -445,6 +447,10 @@ jobs: timeout: 20 - engine: trtllm timeout: 90 + # TokenSpeed builds kernel (CUDA) + scheduler (C++/CMake) from + # source, so first run takes ~30 min; cached runs are faster. + - engine: tokenspeed + timeout: 60 uses: ./.github/workflows/e2e-gpu-job.yml with: engine: ${{ matrix.engine }} @@ -555,6 +561,11 @@ jobs: timeout: 20 - engine: trtllm timeout: 30 + # Picks up TestChatCompletionGptOss (gpt-oss-20b, ``@pytest.mark.gpu(2)``) + # on the tokenspeed engine; the 1-GPU job collected the test class but + # pytest skipped it at collection because the runner only had 1 GPU. + - engine: tokenspeed + timeout: 60 uses: ./.github/workflows/e2e-gpu-job.yml with: engine: ${{ matrix.engine }} diff --git a/.github/workflows/release-pypi-dev.yml b/.github/workflows/release-pypi-dev.yml new file mode 100644 index 000000000..2ce07deca --- /dev/null +++ b/.github/workflows/release-pypi-dev.yml @@ -0,0 +1,449 @@ +name: Release SMG dev wheels (whl index) + +# Builds dev wheels for smg, smg-grpc-proto, and smg-grpc-servicer and +# publishes them to lightseekorg/whl releases + simple package indexes. +# Requires a TOKENSPEED_GITHUB_TOKEN secret with write access to lightseekorg/whl. +# Avoids the 10 GB PyPI project quota and lets us delete old dev releases freely. +# Prod releases continue to go to PyPI via release-pypi.yml / release-grpc.yml. 
+ +on: + workflow_dispatch: + inputs: + release_smg: + description: "Build & release smg (Rust wheel)" + type: boolean + default: true + release_servicer: + description: "Build & release smg-grpc-servicer" + type: boolean + default: true + release_proto: + description: "Build & release smg-grpc-proto" + type: boolean + default: true + # Self-validating: PRs that touch this workflow file auto-build (no release + # is created on pull_request runs; they are pure dry-runs for CI). + pull_request: + paths: + - .github/workflows/release-pypi-dev.yml + +permissions: + contents: read + +jobs: + prepare: + name: Compute dev versions + runs-on: ubuntu-latest + outputs: + smg_version: ${{ steps.compute.outputs.smg_version }} + proto_version: ${{ steps.compute.outputs.proto_version }} + servicer_version: ${{ steps.compute.outputs.servicer_version }} + release_tag: ${{ steps.compute.outputs.release_tag }} + steps: + - uses: actions/checkout@v6 + + - id: compute + run: | + set -euo pipefail + SUFFIX="dev${{ github.run_number }}" + RELEASE_TAG="smg-dev-${{ github.run_number }}" + + read_version() { + grep -m1 '^version = ' "$1" | sed 's/version = "\(.*\)"/\1/' + } + + # Bump patch so dev sorts AFTER current stable in PEP 440 ordering. + bump_patch() { + local IFS=. + # shellcheck disable=SC2206 + local parts=($1) + parts[2]=$(( ${parts[2]} + 1 )) + echo "${parts[0]}.${parts[1]}.${parts[2]}" + } + + SMG_VERSION="$(bump_patch "$(read_version bindings/python/pyproject.toml)").${SUFFIX}" + PROTO_VERSION="$(bump_patch "$(read_version crates/grpc_client/python/pyproject.toml)").${SUFFIX}" + SERVICER_VERSION="$(bump_patch "$(read_version grpc_servicer/pyproject.toml)").${SUFFIX}" + + echo "smg_version=${SMG_VERSION}" >> "$GITHUB_OUTPUT" + echo "proto_version=${PROTO_VERSION}" >> "$GITHUB_OUTPUT" + echo "servicer_version=${SERVICER_VERSION}" >> "$GITHUB_OUTPUT" + echo "release_tag=${RELEASE_TAG}" >> "$GITHUB_OUTPUT" + + { + echo "## Planned dev versions" + echo "| Package | Version | Will build? |" + echo "|---|---|---|" + echo "| smg | ${SMG_VERSION} | ${{ github.event_name == 'pull_request' || inputs.release_smg }} |" + echo "| smg-grpc-proto | ${PROTO_VERSION} | ${{ github.event_name == 'pull_request' || inputs.release_proto }} |" + echo "| smg-grpc-servicer | ${SERVICER_VERSION} | ${{ github.event_name == 'pull_request' || inputs.release_servicer }} |" + echo "" + echo "Release tag (workflow_dispatch only): \`${RELEASE_TAG}\`" + } >> "$GITHUB_STEP_SUMMARY" + + # ── smg (Rust wheel via maturin) — slim matrix: manylinux x86_64 + sdist ── + build-smg: + name: Build smg dev wheel + needs: prepare + if: ${{ github.event_name == 'pull_request' || inputs.release_smg }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + with: + path: smg-repo + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.13" + + # Patches Python metadata + Rust runtime version so the wheel reports a + # consistent dev version everywhere — smg.__version__, the smg.smg_rs + # banner from model_gateway::build.rs, and pyproject.toml. PEP 440 + # (1.4.2.dev42) for Python, semver (1.4.2-dev42) for Cargo. Other + # workspace crates reference model_gateway via path with no version pin. 
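For reference, a minimal Rust sketch (not part of this workflow) of the version arithmetic performed by the compute step above and the patch step below: bump the stable patch version, append the `devN` run number in PEP 440 form for the Python packages, and map `.dev` to `-dev` for the Cargo manifest, mirroring the shell substitution `${SMG_VERSION/.dev/-dev}`. Function and version values here are illustrative only.

```rust
// Illustrative sketch of the workflow's dev-version scheme; not repository code.
fn dev_versions(stable: &str, run_number: u32) -> Option<(String, String)> {
    // bump_patch: "1.4.2" -> "1.4.3", so the dev build sorts after the current stable.
    let mut parts: Vec<u64> = stable
        .split('.')
        .map(|p| p.parse().ok())
        .collect::<Option<Vec<u64>>>()?;
    if parts.len() != 3 {
        return None;
    }
    parts[2] += 1;
    let bumped = format!("{}.{}.{}", parts[0], parts[1], parts[2]);

    let pep440 = format!("{bumped}.dev{run_number}"); // e.g. 1.4.3.dev42 (Python metadata)
    let cargo = pep440.replacen(".dev", "-dev", 1); // e.g. 1.4.3-dev42 (model_gateway/Cargo.toml)
    Some((pep440, cargo))
}

fn main() {
    assert_eq!(
        dev_versions("1.4.2", 42),
        Some(("1.4.3.dev42".to_string(), "1.4.3-dev42".to_string()))
    );
}
```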
+ - name: Patch dev version (Python + model_gateway Cargo) + working-directory: smg-repo + env: + SMG_VERSION: ${{ needs.prepare.outputs.smg_version }} + run: | + set -euo pipefail + SMG_CARGO_VERSION="${SMG_VERSION/.dev/-dev}" + + sed -i.bak "s/^version = \".*\"/version = \"${SMG_VERSION}\"/" bindings/python/pyproject.toml + sed -i.bak "s/__version__ = \".*\"/__version__ = \"${SMG_VERSION}\"/" bindings/python/src/smg/version.py + sed -i.bak "s/^version = \".*\"/version = \"${SMG_CARGO_VERSION}\"/" model_gateway/Cargo.toml + rm -f bindings/python/pyproject.toml.bak bindings/python/src/smg/version.py.bak model_gateway/Cargo.toml.bak + + echo "Patched smg version:" + echo " pyproject.toml + __version__: ${SMG_VERSION}" + echo " model_gateway/Cargo.toml: ${SMG_CARGO_VERSION}" + + - name: Build manylinux x86_64 wheels + uses: PyO3/maturin-action@v1 + with: + working-directory: smg-repo/bindings/python + target: x86_64 + manylinux: auto + # PR runs build a single interpreter (3.12) to keep dry-run fast + # (~6 min vs ~25 min); workflow_dispatch builds the full matrix. + args: --release --out dist --features vendored-openssl --interpreter ${{ github.event_name == 'pull_request' && '3.12' || '3.10 3.11 3.12 3.13' }} + rust-toolchain: stable + before-script-linux: | + if command -v yum &> /dev/null; then + yum update -y && yum install -y wget unzip gcc gcc-c++ perl-core make + elif command -v apt-get &> /dev/null; then + apt-get update && apt-get install -y wget unzip gcc g++ perl make + fi + (cd /tmp && \ + wget https://github.com/protocolbuffers/protobuf/releases/download/v32.0/protoc-32.0-linux-x86_64.zip && \ + unzip protoc-32.0-linux-x86_64.zip -d /usr/local && \ + rm protoc-32.0-linux-x86_64.zip) + protoc --version + + # The manylinux wheel build above runs inside Docker as root; dist/ ends + # up root-owned. Reclaim ownership so the host-side sdist step can write. 
+ - name: Reclaim dist ownership from manylinux container + run: | + sudo chown -R "$(whoami):$(whoami)" smg-repo/bindings/python/dist + ls -la smg-repo/bindings/python/dist + + - name: Build sdist + uses: PyO3/maturin-action@v1 + with: + working-directory: smg-repo/bindings/python + command: sdist + args: --out dist + rust-toolchain: stable + + - name: Check packages + run: | + pip install -U twine + twine check --strict smg-repo/bindings/python/dist/* + + - uses: actions/upload-artifact@v7 + with: + name: smg-dev-dist + path: smg-repo/bindings/python/dist/ + + # ── smg-grpc-proto (pure Python) ── + build-proto: + name: Build smg-grpc-proto dev + needs: prepare + if: ${{ github.event_name == 'pull_request' || inputs.release_proto }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.12" + - name: Install build deps + run: pip install build twine grpcio-tools + - name: Patch dev version + env: + PROTO_VERSION: ${{ needs.prepare.outputs.proto_version }} + run: | + set -euo pipefail + sed -i.bak "s/^version = \".*\"/version = \"${PROTO_VERSION}\"/" \ + crates/grpc_client/python/pyproject.toml + rm -f crates/grpc_client/python/pyproject.toml.bak + - name: Copy proto files (replace symlink with real files) + run: | + rm -f crates/grpc_client/python/smg_grpc_proto/proto + mkdir -p crates/grpc_client/python/smg_grpc_proto/proto + cp crates/grpc_client/proto/*.proto crates/grpc_client/python/smg_grpc_proto/proto/ + - name: Build package + run: cd crates/grpc_client/python && python -m build + - name: Check package + run: twine check --strict crates/grpc_client/python/dist/* + - uses: actions/upload-artifact@v7 + with: + name: proto-dev-dist + path: crates/grpc_client/python/dist/ + + # ── smg-grpc-servicer (pure Python) ── + build-servicer: + name: Build smg-grpc-servicer dev + needs: prepare + if: ${{ github.event_name == 'pull_request' || inputs.release_servicer }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.12" + - name: Install build deps + run: pip install build twine + # When proto is also being built in this run, pin the servicer's proto + # dep to the exact dev version so consumers get a coherent dev set. + - name: Patch dev version (+ pin proto dep if releasing proto together) + env: + SERVICER_VERSION: ${{ needs.prepare.outputs.servicer_version }} + PROTO_VERSION: ${{ needs.prepare.outputs.proto_version }} + RELEASE_PROTO: ${{ github.event_name == 'pull_request' || inputs.release_proto }} + run: | + set -euo pipefail + sed -i.bak "s/^version = \".*\"/version = \"${SERVICER_VERSION}\"/" \ + grpc_servicer/pyproject.toml + if [ "${RELEASE_PROTO}" = "true" ]; then + # Matches both the core dep ("smg-grpc-proto>=0.4.6") and the + # mlx extra ("smg-grpc-proto>=0.4.7"). + sed -i.bak -E "s|smg-grpc-proto>=[0-9.]+|smg-grpc-proto==${PROTO_VERSION}|g" \ + grpc_servicer/pyproject.toml + echo "Pinned smg-grpc-proto==${PROTO_VERSION} for dev coherence" + fi + rm -f grpc_servicer/pyproject.toml.bak + - name: Build package + run: cd grpc_servicer && python -m build + - name: Check package + run: twine check --strict grpc_servicer/dist/* + - uses: actions/upload-artifact@v7 + with: + name: servicer-dev-dist + path: grpc_servicer/dist/ + + # ── Publish dev wheels to lightseekorg/whl and update simple indexes ── + # Only fires for workflow_dispatch — pull_request runs validate the build + # but never publish artifacts or update indexes. 
+ release: + name: Publish to whl index + needs: [prepare, build-smg, build-proto, build-servicer] + # Conditions: + # - Only on workflow_dispatch (no releases on pull_request runs). + # - prepare must have succeeded so we have a tag/versions. + # - No requested build may have failed. + # - At least one build must have produced artifacts. + if: >- + always() + && github.event_name == 'workflow_dispatch' + && github.repository == 'lightseekorg/smg' + && needs.prepare.result == 'success' + && needs.build-smg.result != 'failure' + && needs.build-proto.result != 'failure' + && needs.build-servicer.result != 'failure' + && ( + needs.build-smg.result == 'success' + || needs.build-proto.result == 'success' + || needs.build-servicer.result == 'success' + ) + runs-on: ubuntu-latest + steps: + - name: Download all dev artifacts + uses: actions/download-artifact@v8 + with: + path: dist + pattern: "*-dev-dist" + merge-multiple: true + + - name: List collected assets + run: ls -lh dist/ + + - name: Check out smg repo + uses: actions/checkout@v6 + with: + path: smg-repo + + - name: Check out whl index repo + uses: actions/checkout@v6 + with: + repository: lightseekorg/whl + ref: gh-pages + path: whl-repo + token: ${{ secrets.TOKENSPEED_GITHUB_TOKEN }} + + - name: Create whl release and upload assets + env: + GH_TOKEN: ${{ secrets.TOKENSPEED_GITHUB_TOKEN }} + RELEASE_TAG: ${{ needs.prepare.outputs.release_tag }} + WHL_REPO: lightseekorg/whl + SMG_VERSION: ${{ needs.prepare.outputs.smg_version }} + PROTO_VERSION: ${{ needs.prepare.outputs.proto_version }} + SERVICER_VERSION: ${{ needs.prepare.outputs.servicer_version }} + run: | + set -euo pipefail + + notes_file="$(mktemp)" + cat > "${notes_file}" </dev/null 2>&1; then + gh release edit "${RELEASE_TAG}" \ + --repo "${WHL_REPO}" \ + --title "smg dev ${{ github.run_number }}" \ + --prerelease \ + --notes-file "${notes_file}" + else + gh release create "${RELEASE_TAG}" \ + --repo "${WHL_REPO}" \ + --title "smg dev ${{ github.run_number }}" \ + --prerelease \ + --notes-file "${notes_file}" + fi + + gh release upload "${RELEASE_TAG}" dist/* --repo "${WHL_REPO}" --clobber + + - name: Update whl package indexes + env: + RELEASE_TAG: ${{ needs.prepare.outputs.release_tag }} + run: | + set -euo pipefail + + cd whl-repo + git config user.name "github-actions[bot]" + git config user.email "41898282+github-actions[bot]@users.noreply.github.com" + + rm -rf wheelhouse + mkdir -p wheelhouse/smg wheelhouse/smg-grpc-proto wheelhouse/smg-grpc-servicer + cp ../dist/smg-*.whl wheelhouse/smg/ 2>/dev/null || true + cp ../dist/smg_grpc_proto-*.whl wheelhouse/smg-grpc-proto/ 2>/dev/null || true + cp ../dist/smg_grpc_servicer-*.whl wheelhouse/smg-grpc-servicer/ 2>/dev/null || true + + # Make workflow re-runs idempotent for the same release tag. The + # shared update_whl_index.py script appends wheel links by design. + for index_file in \ + cu129/smg/index.html \ + cu129/smg-grpc-proto/index.html \ + cu129/smg-grpc-servicer/index.html \ + cu130/smg/index.html \ + cu130/smg-grpc-proto/index.html \ + cu130/smg-grpc-servicer/index.html \ + rocm7.2/smg/index.html \ + rocm7.2/smg-grpc-proto/index.html \ + rocm7.2/smg-grpc-servicer/index.html; do + [ -f "${index_file}" ] && sed -i "\#/${RELEASE_TAG}/#d" "${index_file}" + done + + for cuda in 129 130; do + python3 ../smg-repo/scripts/update_whl_index.py \ + --package smg \ + --cuda "${cuda}" \ + --release-tag "${RELEASE_TAG}" \ + --wheel-dir wheelhouse/smg \ + --whl-repo-dir . 
+ python3 ../smg-repo/scripts/update_whl_index.py \ + --package smg-grpc-proto \ + --cuda "${cuda}" \ + --release-tag "${RELEASE_TAG}" \ + --wheel-dir wheelhouse/smg-grpc-proto \ + --whl-repo-dir . + python3 ../smg-repo/scripts/update_whl_index.py \ + --package smg-grpc-servicer \ + --cuda "${cuda}" \ + --release-tag "${RELEASE_TAG}" \ + --wheel-dir wheelhouse/smg-grpc-servicer \ + --whl-repo-dir . + done + + for package in smg smg-grpc-proto smg-grpc-servicer; do + python3 ../smg-repo/scripts/update_whl_index.py \ + --package "${package}" \ + --rocm 7.2 \ + --release-tag "${RELEASE_TAG}" \ + --wheel-dir "wheelhouse/${package}" \ + --whl-repo-dir . + done + + rm -rf wheelhouse + git add cu129 cu130 rocm7.2 + if git diff --cached --quiet; then + echo "No whl index changes to commit." + else + git commit -s -m "Add smg dev ${{ github.run_number }} wheels" + git push origin gh-pages + fi + + # ── Final summary in the run UI ── + summary: + name: Summary + needs: + - prepare + - build-smg + - build-proto + - build-servicer + - release + if: always() + runs-on: ubuntu-latest + steps: + - run: | + { + if [ "${{ github.event_name }}" = "pull_request" ]; then + echo "## Dev build summary (pull_request — no GitHub Release created)" + else + echo "## Dev release summary" + fi + echo "" + echo "| Package | Version | Build |" + echo "|---|---|---|" + echo "| smg | ${{ needs.prepare.outputs.smg_version }} | ${{ needs.build-smg.result }} |" + echo "| smg-grpc-proto | ${{ needs.prepare.outputs.proto_version }} | ${{ needs.build-proto.result }} |" + echo "| smg-grpc-servicer | ${{ needs.prepare.outputs.servicer_version }} | ${{ needs.build-servicer.result }} |" + echo "" + if [ "${{ github.event_name }}" = "workflow_dispatch" ] && [ "${{ needs.release.result }}" = "success" ]; then + TAG="${{ needs.prepare.outputs.release_tag }}" + REPO="${{ github.repository }}" + echo "### Release" + echo "" + echo "https://github.com/lightseekorg/whl/releases/tag/${TAG}" + echo "" + echo "### Install" + echo '```bash' + echo "pip install smg smg-grpc-servicer smg-grpc-proto \\" + echo " --extra-index-url https://lightseek.org/whl/cu129/" + echo '```' + fi + } >> "$GITHUB_STEP_SUMMARY" diff --git a/crates/grpc_client/build.rs b/crates/grpc_client/build.rs index 9809b80b1..8a8730c1a 100644 --- a/crates/grpc_client/build.rs +++ b/crates/grpc_client/build.rs @@ -2,6 +2,7 @@ fn main() -> Result<(), Box> { // Rebuild triggers println!("cargo:rerun-if-changed=proto/common.proto"); println!("cargo:rerun-if-changed=proto/sglang_scheduler.proto"); + println!("cargo:rerun-if-changed=proto/tokenspeed_scheduler.proto"); println!("cargo:rerun-if-changed=proto/vllm_engine.proto"); println!("cargo:rerun-if-changed=proto/trtllm_service.proto"); println!("cargo:rerun-if-changed=proto/mlx_engine.proto"); @@ -19,8 +20,8 @@ fn main() -> Result<(), Box> { .build_client(true) .extern_path(".smg.grpc.common", "crate::common_proto") .type_attribute("GetModelInfoResponse", "#[derive(serde::Serialize)]") - // vllm + trtllm ServerInfo have only primitive fields. - // sglang's contains prost_types::{Struct,Timestamp} so it's handled separately. + // Some ServerInfo protos contain prost_types::{Struct, Timestamp}; + // those are handled separately at the wrapper layer. 
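For context, a hedged sketch of the kind of wrapper-layer handling the comment above refers to: a response message containing a `prost_types::Struct` field cannot simply have `#[derive(serde::Serialize)]` attached via `type_attribute`, so the conversion to JSON is done by hand where the response is exposed. This is an assumption about one workable approach, not the router's actual code; it only relies on the public `prost-types` and `serde_json` APIs.

```rust
// Hedged sketch: manually serialize a prost_types::Struct (e.g. server_args /
// scheduler_info in GetServerInfoResponse) into serde_json::Value.
use prost_types::{value::Kind, ListValue, Struct, Value};

fn struct_to_json(s: &Struct) -> serde_json::Value {
    serde_json::Value::Object(
        s.fields
            .iter()
            .map(|(k, v)| (k.clone(), value_to_json(v)))
            .collect(),
    )
}

fn value_to_json(v: &Value) -> serde_json::Value {
    match &v.kind {
        None | Some(Kind::NullValue(_)) => serde_json::Value::Null,
        Some(Kind::BoolValue(b)) => serde_json::Value::Bool(*b),
        Some(Kind::NumberValue(n)) => serde_json::Number::from_f64(*n)
            .map(serde_json::Value::Number)
            .unwrap_or(serde_json::Value::Null),
        Some(Kind::StringValue(s)) => serde_json::Value::String(s.clone()),
        Some(Kind::StructValue(nested)) => struct_to_json(nested),
        Some(Kind::ListValue(ListValue { values })) => {
            serde_json::Value::Array(values.iter().map(value_to_json).collect())
        }
    }
}

fn main() {
    // Tiny demo with a hand-built Struct; field name chosen for illustration only.
    let mut s = Struct::default();
    s.fields.insert(
        "active_requests".to_string(),
        Value { kind: Some(Kind::NumberValue(3.0)) },
    );
    println!("{}", struct_to_json(&s));
}
```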
.type_attribute( "vllm.grpc.engine.GetServerInfoResponse", "#[derive(serde::Serialize)]", @@ -40,6 +41,7 @@ fn main() -> Result<(), Box> { "proto/vllm_engine.proto", "proto/trtllm_service.proto", "proto/mlx_engine.proto", + "proto/tokenspeed_scheduler.proto", ], &["proto"], )?; diff --git a/crates/grpc_client/proto/tokenspeed_scheduler.proto b/crates/grpc_client/proto/tokenspeed_scheduler.proto new file mode 100644 index 000000000..8062050b9 --- /dev/null +++ b/crates/grpc_client/proto/tokenspeed_scheduler.proto @@ -0,0 +1,265 @@ +syntax = "proto3"; + +package tokenspeed.grpc.scheduler; + +import "google/protobuf/timestamp.proto"; +import "google/protobuf/struct.proto"; + +// TokenSpeed scheduler gRPC service. Fully self-contained wire definition. +// Trimmed to text-generation only (no embed, no multimodal, no +// PD-disaggregated, no LoRA, no hidden-state forwarding). +service TokenSpeedScheduler { + rpc Generate(GenerateRequest) returns (stream GenerateResponse); + rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse); + rpc Abort(AbortRequest) returns (AbortResponse); + rpc GetModelInfo(GetModelInfoRequest) returns (GetModelInfoResponse); + rpc GetServerInfo(GetServerInfoRequest) returns (GetServerInfoResponse); + rpc GetLoads(GetLoadsRequest) returns (GetLoadsResponse); +} + +// ===================== +// Sampling +// ===================== + +// Sampling scalars are `optional` so the servicer can distinguish +// "client set 0" from "client unset" via `HasField()`. `min_new_tokens` +// stays non-optional because 0 is its semantic "no minimum" sentinel. +message SamplingParams { + optional float temperature = 1; + optional float top_p = 2; + optional int32 top_k = 3; + optional float min_p = 4; + optional float frequency_penalty = 5; + optional float presence_penalty = 6; + optional float repetition_penalty = 7; + + optional uint32 max_new_tokens = 8; + uint32 min_new_tokens = 9; + + repeated string stop = 10; + repeated uint32 stop_token_ids = 11; + bool ignore_eos = 12; + + bool skip_special_tokens = 13; + bool spaces_between_special_tokens = 14; + + // Number of samples (n in OpenAI API). + uint32 n = 15; + + // Per-token logit bias. + map logit_bias = 16; + + // Structured generation. Currently xfailed in e2e (tokenspeed#361). + oneof constraint { + string regex = 17; + string json_schema = 18; + string ebnf_grammar = 19; + string structural_tag = 20; + } + + // Keep the trailing matched stop token in `output_ids`. + bool no_stop_trim = 22; + + // Escape hatch for backend-specific knobs without bumping the proto. + google.protobuf.Struct custom_params = 21; +} + +// ===================== +// Generate +// ===================== + +message GenerateRequest { + string request_id = 1; + + // Tokenized input (router does its own tokenization). + TokenizedInput tokenized = 2; + + SamplingParams sampling_params = 3; + + // Logprob options. + bool return_logprob = 4; + // Optional: unset = no input logprobs; explicit 0 = "from start". + optional int32 logprob_start_len = 5; + int32 top_logprobs_num = 6; + repeated uint32 token_ids_logprob = 7; + + // Whether the client wants stream chunks (otherwise: complete-only). + bool stream = 8; +} + +message TokenizedInput { + repeated uint32 input_ids = 1; + string original_text = 2; // cosmetic, for worker logs +} + +message GenerateResponse { + string request_id = 1; + + oneof response { + GenerateStreamChunk chunk = 2; + GenerateComplete complete = 3; + } +} + +message GenerateStreamChunk { + // Generated tokens since the previous chunk. 
+ repeated uint32 token_ids = 1; + + uint32 prompt_tokens = 2; + uint32 completion_tokens = 3; + uint32 cached_tokens = 4; + + OutputLogProbs output_logprobs = 5; + + // For ordering when n>1. + uint32 index = 6; +} + +message GenerateComplete { + repeated uint32 output_ids = 1; + + // OpenAI-compatible: "stop", "length", "abort", "tool_calls". + string finish_reason = 2; + + uint32 prompt_tokens = 3; + uint32 completion_tokens = 4; + uint32 cached_tokens = 5; + + OutputLogProbs output_logprobs = 6; + + // Which stop matched (for clients that care which `stop` triggered). + oneof matched_stop { + uint32 matched_token_id = 7; + string matched_stop_str = 8; + } + + uint32 index = 9; +} + +message OutputLogProbs { + repeated float token_logprobs = 1; + repeated uint32 token_ids = 2; + repeated TopLogProbs top_logprobs = 3; +} + +message TopLogProbs { + repeated float values = 1; + repeated uint32 token_ids = 2; +} + +// ===================== +// Management +// ===================== + +message HealthCheckRequest {} +message HealthCheckResponse { + bool healthy = 1; + string message = 2; +} + +message AbortRequest { + string request_id = 1; + string reason = 2; +} +message AbortResponse { + bool success = 1; + string message = 2; +} + +// ===================== +// Model & Server Info +// ===================== + +message GetModelInfoRequest {} +message GetModelInfoResponse { + string model_path = 1; + string tokenizer_path = 2; + string served_model_name = 3; + string model_type = 4; + repeated string architectures = 5; + + int32 max_context_length = 6; + int32 max_req_input_len = 7; + int32 vocab_size = 8; + + repeated int32 eos_token_ids = 9; + int32 pad_token_id = 10; + int32 bos_token_id = 11; + + string weight_version = 12; + string preferred_sampling_params = 13; // JSON string or empty +} + +message GetServerInfoRequest {} +message GetServerInfoResponse { + google.protobuf.Struct server_args = 1; + google.protobuf.Struct scheduler_info = 2; + + int32 active_requests = 3; + bool is_paused = 4; + double uptime_seconds = 5; + int32 max_total_num_tokens = 6; + + string tokenspeed_version = 7; + google.protobuf.Timestamp start_time = 8; +} + +// ===================== +// Loads +// ===================== + +message GetLoadsRequest { + optional int32 dp_rank = 1; + // Sections: "core" (default), "memory", "queues". Pass "all" for everything. 
+ repeated string include = 2; +} + +message GetLoadsResponse { + string timestamp = 1; + string version = 2; + int32 dp_rank_count = 3; + repeated SchedulerLoad loads = 4; + AggregateMetrics aggregate = 5; +} + +message SchedulerLoad { + int32 dp_rank = 1; + + int32 num_running_reqs = 2; + int32 num_waiting_reqs = 3; + int32 num_total_reqs = 4; + int32 num_used_tokens = 5; + int32 max_total_num_tokens = 6; + int32 max_running_requests = 7; + + double token_usage = 8; + double gen_throughput = 9; + double cache_hit_rate = 10; + double utilization = 11; + + optional MemoryMetrics memory = 12; + optional QueueMetrics queues = 13; +} + +message MemoryMetrics { + double weight_gb = 1; + double kv_cache_gb = 2; + double graph_gb = 3; + int32 token_capacity = 4; +} + +message QueueMetrics { + int32 waiting = 1; + int32 grammar = 2; + int32 paused = 3; + int32 retracted = 4; +} + +message AggregateMetrics { + int32 total_running_reqs = 1; + int32 total_waiting_reqs = 2; + int32 total_reqs = 3; + double avg_token_usage = 4; + double avg_throughput = 5; + double avg_utilization = 6; +} diff --git a/crates/grpc_client/python/smg_grpc_proto/__init__.py b/crates/grpc_client/python/smg_grpc_proto/__init__.py index 6a19b4aea..f7eac4a3e 100644 --- a/crates/grpc_client/python/smg_grpc_proto/__init__.py +++ b/crates/grpc_client/python/smg_grpc_proto/__init__.py @@ -1,4 +1,4 @@ -"""SMG gRPC Proto - Protocol definitions for SGLang, vLLM, TRT-LLM, and MLX.""" +"""SMG gRPC Proto - Protocol definitions for SGLang, TokenSpeed, vLLM, TRT-LLM, and MLX.""" from importlib.metadata import version @@ -14,6 +14,8 @@ sglang_encoder_pb2_grpc, sglang_scheduler_pb2, sglang_scheduler_pb2_grpc, + tokenspeed_scheduler_pb2, + tokenspeed_scheduler_pb2_grpc, trtllm_service_pb2, trtllm_service_pb2_grpc, vllm_engine_pb2, @@ -25,6 +27,8 @@ "sglang_scheduler_pb2_grpc", "sglang_encoder_pb2", "sglang_encoder_pb2_grpc", + "tokenspeed_scheduler_pb2", + "tokenspeed_scheduler_pb2_grpc", "vllm_engine_pb2", "vllm_engine_pb2_grpc", "trtllm_service_pb2", diff --git a/crates/grpc_client/src/lib.rs b/crates/grpc_client/src/lib.rs index 77c4fa5a2..0e33a0169 100644 --- a/crates/grpc_client/src/lib.rs +++ b/crates/grpc_client/src/lib.rs @@ -1,15 +1,14 @@ -//! gRPC clients for SGLang, vLLM, TensorRT-LLM, and MLX backends -//! -//! This crate provides gRPC client implementations for communicating with -//! SGLang scheduler, vLLM engine, TensorRT-LLM engine, and MLX engine backends. +//! gRPC clients for the supported inference backends. pub mod common_proto { #![allow(clippy::all, clippy::absolute_paths, unused_qualifications)] tonic::include_proto!("smg.grpc.common"); } pub mod mlx_engine; +pub mod sampling_params; pub mod sglang_scheduler; pub mod tokenizer_bundle; +pub mod tokenspeed_scheduler; pub mod trtllm_service; pub mod vllm_engine; @@ -18,6 +17,7 @@ use std::sync::Arc; pub use mlx_engine::{proto as mlx_proto, MlxEngineClient}; pub use sglang_scheduler::{proto as sglang_proto, SglangSchedulerClient}; +pub use tokenspeed_scheduler::{tokenspeed_proto, TokenSpeedSchedulerClient}; use tonic::metadata::MetadataMap; pub use trtllm_service::{proto as trtllm_proto, TrtllmServiceClient}; pub use vllm_engine::{proto as vllm_proto, VllmEngineClient}; diff --git a/crates/grpc_client/src/sampling_params.rs b/crates/grpc_client/src/sampling_params.rs new file mode 100644 index 000000000..4a28e7076 --- /dev/null +++ b/crates/grpc_client/src/sampling_params.rs @@ -0,0 +1,361 @@ +//! Backend-neutral OpenAI → sampling-params builders. The return type +//! 
happens to be [`proto::SamplingParams`] (the most permissive shape +//! today); other backends translate from this at their wire seam. + +use openai_protocol::{ + chat::ChatCompletionRequest, + common::{ResponseFormat, StringOrArray}, + completion::CompletionRequest, + messages::CreateMessageRequest, + responses::ResponsesRequest, + sampling_params::SamplingParams as GenerateSamplingParams, +}; +use tracing::warn; + +use crate::sglang_scheduler::proto; + +/// Build gRPC `SamplingParams` from a `ChatCompletionRequest`. +pub fn build_grpc_sampling_params_from_chat( + request: &ChatCompletionRequest, + tool_call_constraint: Option<(String, String)>, +) -> Result { + let stop_sequences = extract_stop_strings(request); + + let max_new_tokens = request.max_completion_tokens; + + // Hardcode to true: gRPC backends return raw token IDs, not decoded text. + // Detokenization happens on the SMG Rust side (StopDecoder/Sequence). + // + // Note: TokenSpeed's HTTP serving_chat sets this to false when tools are + // present (serving_chat.py:178-179) — but mirroring that on the gRPC + // path measurably HURTS BFCL accuracy. We tested it: simple_python + // dropped from ~88.75 % to 79 %, parallel_multiple from ~84.5 % to + // 60.5 %. With skip_special_tokens=false the engine emits the + // ``<|tool_call_*|>`` special tokens in the raw output stream, and the + // SMG-side detokenizer + kimik2 tool-call parser then double-counts or + // misframes them. Keep it at true so SMG sees normal tokens and + // applies its own parsing. + let skip_special_tokens = true; + + Ok(proto::SamplingParams { + temperature: request.temperature.unwrap_or(1.0), + top_p: request.top_p.unwrap_or(1.0), + top_k: request.top_k.unwrap_or(-1), + min_p: request.min_p.unwrap_or(0.0), + frequency_penalty: request.frequency_penalty.unwrap_or(0.0), + presence_penalty: request.presence_penalty.unwrap_or(0.0), + repetition_penalty: request.repetition_penalty.unwrap_or(1.0), + max_new_tokens, + stop: stop_sequences, + stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), + skip_special_tokens, + spaces_between_special_tokens: true, // Default from Python SamplingParams + ignore_eos: request.ignore_eos, + no_stop_trim: request.no_stop_trim, + n: request.n.unwrap_or(1), + constraint: build_constraint_for_chat(request, tool_call_constraint)?, + ..Default::default() + }) +} + +/// Build gRPC `SamplingParams` from a `ResponsesRequest`. +/// +/// Used by Harmony models only. Regular models use the Chat API path. +/// Constraints come from the Harmony preparation stage (`structural_tag`) +/// or tool handling. 
+pub fn build_grpc_sampling_params_from_responses( + request: &ResponsesRequest, + constraint: Option<(String, String)>, +) -> Result { + let max_new_tokens = request.max_output_tokens; + + Ok(proto::SamplingParams { + temperature: request.temperature.unwrap_or(1.0), + top_p: request.top_p.unwrap_or(1.0), + top_k: request.top_k, + min_p: request.min_p, + frequency_penalty: request.frequency_penalty.unwrap_or(0.0), + presence_penalty: request.presence_penalty.unwrap_or(0.0), + repetition_penalty: request.repetition_penalty, + max_new_tokens, + stop: vec![], // Does not pass through request.stop yet (follow-up fix) + stop_token_ids: vec![], // Handled by Harmony stop tokens + skip_special_tokens: false, // Keep special tokens for Harmony + spaces_between_special_tokens: true, + ignore_eos: false, + no_stop_trim: false, + n: 1, // Responses API doesn't support n>1 + constraint: build_constraint_for_responses(constraint)?, + ..Default::default() + }) +} + +/// Build gRPC `SamplingParams` from a `CreateMessageRequest` (Anthropic +/// Messages API). +pub fn build_grpc_sampling_params_from_messages( + request: &CreateMessageRequest, + tool_call_constraint: Option<(String, String)>, +) -> Result { + let stop_sequences = request.stop_sequences.clone().unwrap_or_default(); + + // Hardcode to true: gRPC backends return raw token IDs, not decoded text. + let skip_special_tokens = true; + + Ok(proto::SamplingParams { + temperature: request.temperature.unwrap_or(1.0) as f32, + top_p: request.top_p.unwrap_or(1.0) as f32, + top_k: request.top_k.map(|v| v as i32).unwrap_or(-1), + min_p: 0.0, + frequency_penalty: 0.0, + presence_penalty: 0.0, + repetition_penalty: 1.0, + max_new_tokens: Some(request.max_tokens), + stop: stop_sequences, + stop_token_ids: vec![], + skip_special_tokens, + spaces_between_special_tokens: true, + ignore_eos: false, + no_stop_trim: false, + n: 1, + constraint: build_constraint_for_responses(tool_call_constraint)?, + ..Default::default() + }) +} + +/// Build gRPC `SamplingParams` from a `CompletionRequest` +/// (`/v1/completions`). +pub fn build_grpc_sampling_params_from_completion( + request: &CompletionRequest, +) -> Result { + let stop_sequences = match &request.stop { + Some(StringOrArray::String(s)) => vec![s.clone()], + Some(StringOrArray::Array(arr)) => arr.clone(), + None => vec![], + }; + + let constraint = build_single_constraint_from_completion(request)?; + + Ok(proto::SamplingParams { + temperature: request.temperature.unwrap_or(1.0), + top_p: request.top_p.unwrap_or(1.0), + top_k: request.top_k.unwrap_or(-1), + min_p: request.min_p.unwrap_or(0.0), + frequency_penalty: request.frequency_penalty.unwrap_or(0.0), + presence_penalty: request.presence_penalty.unwrap_or(0.0), + repetition_penalty: request.repetition_penalty.unwrap_or(1.0), + max_new_tokens: request.max_tokens, + min_new_tokens: request.min_tokens.unwrap_or(0), + stop: stop_sequences, + stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), + skip_special_tokens: request.skip_special_tokens, + spaces_between_special_tokens: true, + ignore_eos: request.ignore_eos, + no_stop_trim: request.no_stop_trim, + n: request.n.unwrap_or(1), + constraint, + ..Default::default() + }) +} + +/// Build gRPC `SamplingParams` from the plain `GenerateSamplingParams` +/// shape used by `/generate`. 
+pub fn build_sampling_params_from_plain( + params: Option<&GenerateSamplingParams>, +) -> Result { + let mut sampling = proto::SamplingParams { + temperature: 1.0, + top_p: 1.0, + top_k: -1, + repetition_penalty: 1.0, + n: 1, + skip_special_tokens: true, + spaces_between_special_tokens: true, + ..Default::default() + }; + + let Some(p) = params else { + return Ok(sampling); + }; + + macro_rules! map_field { + ($field:ident) => { + if let Some(val) = p.$field { + sampling.$field = val; + } + }; + } + + map_field!(temperature); + map_field!(top_p); + map_field!(top_k); + map_field!(frequency_penalty); + map_field!(presence_penalty); + map_field!(repetition_penalty); + map_field!(min_p); + map_field!(ignore_eos); + map_field!(skip_special_tokens); + map_field!(no_stop_trim); + + if let Some(stop) = &p.stop { + match stop { + StringOrArray::String(s) => sampling.stop.push(s.clone()), + StringOrArray::Array(arr) => sampling.stop.extend(arr.clone()), + } + } + + if let Some(stop_token_ids) = &p.stop_token_ids { + sampling.stop_token_ids.clone_from(stop_token_ids); + } + + sampling.max_new_tokens = p.max_new_tokens; + + if let Some(min_new_tokens) = p.min_new_tokens { + sampling.min_new_tokens = min_new_tokens; + } + + if let Some(n) = p.n { + sampling.n = n; + } + + sampling.constraint = build_single_constraint_from_plain(p)?; + + Ok(sampling) +} + +// --------------------------------------------------------------------------- +// Constraint helpers +// --------------------------------------------------------------------------- + +fn extract_stop_strings(request: &ChatCompletionRequest) -> Vec { + match &request.stop { + Some(StringOrArray::String(s)) => vec![s.clone()], + Some(StringOrArray::Array(arr)) => arr.clone(), + None => vec![], + } +} + +fn build_constraint_for_chat( + request: &ChatCompletionRequest, + tool_call_constraint: Option<(String, String)>, +) -> Result, String> { + let mut constraints = Vec::new(); + + match &request.response_format { + Some(ResponseFormat::JsonObject) => { + let schema = serde_json::json!({"type": "object"}); + let schema_str = serde_json::to_string(&schema) + .map_err(|e| format!("Failed to serialize JSON schema: {e}"))?; + constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str)); + } + Some(ResponseFormat::JsonSchema { json_schema }) => { + let schema_str = serde_json::to_string(&json_schema.schema) + .map_err(|e| format!("Failed to serialize JSON schema: {e}"))?; + constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str)); + } + Some(ResponseFormat::Text) | None => {} + } + + if let Some(ebnf) = &request.ebnf { + constraints.push(proto::sampling_params::Constraint::EbnfGrammar( + ebnf.clone(), + )); + } + + if let Some(regex) = &request.regex { + constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); + } + + // If response_format already set a constraint, drop the tool constraint + // (response_format takes priority over tool-call constraints). 
+ if let Some((constraint_type, constraint_value)) = tool_call_constraint { + if constraints.is_empty() { + let tool_constraint = match constraint_type.as_str() { + "structural_tag" => { + proto::sampling_params::Constraint::StructuralTag(constraint_value) + } + "json_schema" => proto::sampling_params::Constraint::JsonSchema(constraint_value), + "ebnf" => proto::sampling_params::Constraint::EbnfGrammar(constraint_value), + "regex" => proto::sampling_params::Constraint::Regex(constraint_value), + _ => return Err(format!("Unknown constraint type: {constraint_type}")), + }; + constraints.push(tool_constraint); + } else { + warn!( + "Constrained decoding is not compatible with tool calls, dropping tool constraint" + ); + } + } + + match constraints.len() { + 0 => Ok(None), + 1 => Ok(constraints.pop()), + _ => Err("Multiple constraints are not allowed.".to_string()), + } +} + +fn build_constraint_for_responses( + constraint: Option<(String, String)>, +) -> Result, String> { + if let Some((constraint_type, constraint_value)) = constraint { + let parsed_constraint = match constraint_type.as_str() { + "structural_tag" => proto::sampling_params::Constraint::StructuralTag(constraint_value), + "json_schema" => proto::sampling_params::Constraint::JsonSchema(constraint_value), + "ebnf" => proto::sampling_params::Constraint::EbnfGrammar(constraint_value), + "regex" => proto::sampling_params::Constraint::Regex(constraint_value), + _ => return Err(format!("Unknown constraint type: {constraint_type}")), + }; + Ok(Some(parsed_constraint)) + } else { + Ok(None) + } +} + +fn build_single_constraint_from_completion( + request: &CompletionRequest, +) -> Result, String> { + let mut constraints = Vec::new(); + if let Some(json_schema) = &request.json_schema { + constraints.push(proto::sampling_params::Constraint::JsonSchema( + json_schema.clone(), + )); + } + if let Some(regex) = &request.regex { + constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); + } + if let Some(ebnf) = &request.ebnf { + constraints.push(proto::sampling_params::Constraint::EbnfGrammar( + ebnf.clone(), + )); + } + + match constraints.len() { + 0 => Ok(None), + 1 => Ok(constraints.pop()), + _ => Err("Multiple structured constraints are not allowed".to_string()), + } +} + +fn build_single_constraint_from_plain( + params: &GenerateSamplingParams, +) -> Result, String> { + let mut constraints = Vec::new(); + if let Some(json_schema) = ¶ms.json_schema { + constraints.push(proto::sampling_params::Constraint::JsonSchema( + json_schema.clone(), + )); + } + if let Some(regex) = ¶ms.regex { + constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); + } + if let Some(ebnf) = ¶ms.ebnf { + constraints.push(proto::sampling_params::Constraint::EbnfGrammar( + ebnf.clone(), + )); + } + + match constraints.len() { + 0 => Ok(None), + 1 => Ok(constraints.pop()), + _ => Err("Multiple structured constraints are not allowed".to_string()), + } +} diff --git a/crates/grpc_client/src/sglang_scheduler.rs b/crates/grpc_client/src/sglang_scheduler.rs index edfc43e0f..183d04bac 100644 --- a/crates/grpc_client/src/sglang_scheduler.rs +++ b/crates/grpc_client/src/sglang_scheduler.rs @@ -9,13 +9,8 @@ use std::{ }; use openai_protocol::{ - chat::ChatCompletionRequest, - common::{ResponseFormat, StringOrArray}, - completion::CompletionRequest, - generate::GenerateRequest, - messages::CreateMessageRequest, - responses::ResponsesRequest, - sampling_params::SamplingParams as GenerateSamplingParams, + chat::ChatCompletionRequest, 
completion::CompletionRequest, generate::GenerateRequest, + messages::CreateMessageRequest, responses::ResponsesRequest, }; use tonic::{transport::Channel, Request, Streaming}; use tracing::{debug, warn}; @@ -337,8 +332,10 @@ impl SglangSchedulerClient { tool_call_constraint: Option<(String, String)>, // (constraint_type, constraint_value) ) -> Result { // Build sampling params - let sampling_params = - Self::build_grpc_sampling_params_from_chat(body, tool_call_constraint)?; + let sampling_params = crate::sampling_params::build_grpc_sampling_params_from_chat( + body, + tool_call_constraint, + )?; let grpc_request = proto::GenerateRequest { request_id, @@ -371,8 +368,9 @@ impl SglangSchedulerClient { original_text: Option, token_ids: Vec, ) -> Result { - let sampling_params = - Self::build_sampling_params_from_plain(body.sampling_params.as_ref())?; + let sampling_params = crate::sampling_params::build_sampling_params_from_plain( + body.sampling_params.as_ref(), + )?; let grpc_request = proto::GenerateRequest { request_id, @@ -411,7 +409,8 @@ impl SglangSchedulerClient { constraint: Option<(String, String)>, ) -> Result { // Build sampling params from ResponsesRequest - let sampling_params = Self::build_grpc_sampling_params_from_responses(body, constraint)?; + let sampling_params = + crate::sampling_params::build_grpc_sampling_params_from_responses(body, constraint)?; let grpc_request = proto::GenerateRequest { request_id, @@ -433,171 +432,6 @@ impl SglangSchedulerClient { Ok(grpc_request) } - /// Build gRPC SamplingParams from ChatCompletionRequest - fn build_grpc_sampling_params_from_chat( - request: &ChatCompletionRequest, - tool_call_constraint: Option<(String, String)>, - ) -> Result { - let stop_sequences = Self::extract_stop_strings(request); - - let max_new_tokens = request.max_completion_tokens; - - // Hardcode to true: gRPC backends return raw token IDs, not decoded text. - // Detokenization happens on the SMG Rust side (StopDecoder/Sequence). 
- let skip_special_tokens = true; - - Ok(proto::SamplingParams { - temperature: request.temperature.unwrap_or(1.0), - top_p: request.top_p.unwrap_or(1.0), - top_k: request.top_k.unwrap_or(-1), - min_p: request.min_p.unwrap_or(0.0), - frequency_penalty: request.frequency_penalty.unwrap_or(0.0), - presence_penalty: request.presence_penalty.unwrap_or(0.0), - repetition_penalty: request.repetition_penalty.unwrap_or(1.0), - max_new_tokens, - stop: stop_sequences, - stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), - skip_special_tokens, - spaces_between_special_tokens: true, // Default from Python SamplingParams - ignore_eos: request.ignore_eos, - no_stop_trim: request.no_stop_trim, - n: request.n.unwrap_or(1), - constraint: Self::build_constraint_for_chat(request, tool_call_constraint)?, - ..Default::default() - }) - } - - /// Extract stop strings from request - fn extract_stop_strings(request: &ChatCompletionRequest) -> Vec { - match &request.stop { - Some(StringOrArray::String(s)) => vec![s.clone()], - Some(StringOrArray::Array(arr)) => arr.clone(), - None => vec![], - } - } - - /// Build constraint for structured generation - fn build_constraint_for_chat( - request: &ChatCompletionRequest, - tool_call_constraint: Option<(String, String)>, - ) -> Result, String> { - let mut constraints = Vec::new(); - - // Handle response_format constraints - match &request.response_format { - Some(ResponseFormat::JsonObject) => { - // json_object mode - constrain to valid JSON object - let schema = serde_json::json!({"type": "object"}); - let schema_str = serde_json::to_string(&schema) - .map_err(|e| format!("Failed to serialize JSON schema: {e}"))?; - constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str)); - } - Some(ResponseFormat::JsonSchema { json_schema }) => { - let schema_str = serde_json::to_string(&json_schema.schema) - .map_err(|e| format!("Failed to serialize JSON schema: {e}"))?; - constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str)); - } - Some(ResponseFormat::Text) | None => { - // No constraint for text format - } - } - - if let Some(ebnf) = &request.ebnf { - constraints.push(proto::sampling_params::Constraint::EbnfGrammar( - ebnf.clone(), - )); - } - - if let Some(regex) = &request.regex { - constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); - } - - // Handle tool call constraint from preparation stage. - // If response_format already set a constraint, drop the tool constraint - // (matches SGLang HTTP behavior where response_format takes priority). 
- if let Some((constraint_type, constraint_value)) = tool_call_constraint { - if constraints.is_empty() { - let tool_constraint = match constraint_type.as_str() { - "structural_tag" => { - proto::sampling_params::Constraint::StructuralTag(constraint_value) - } - "json_schema" => { - proto::sampling_params::Constraint::JsonSchema(constraint_value) - } - "ebnf" => proto::sampling_params::Constraint::EbnfGrammar(constraint_value), - "regex" => proto::sampling_params::Constraint::Regex(constraint_value), - _ => return Err(format!("Unknown constraint type: {constraint_type}")), - }; - constraints.push(tool_constraint); - } else { - warn!("Constrained decoding is not compatible with tool calls, dropping tool constraint"); - } - } - - match constraints.len() { - 0 => Ok(None), - 1 => Ok(constraints.pop()), - _ => Err("Multiple constraints are not allowed.".to_string()), - } - } - - /// Build gRPC SamplingParams from ResponsesRequest - fn build_grpc_sampling_params_from_responses( - request: &ResponsesRequest, - constraint: Option<(String, String)>, - ) -> Result { - // Used by Harmony models only. Regular models use Chat API path. - // Constraints come from Harmony preparation stage (structural_tag) or tool handling. - - let max_new_tokens = request.max_output_tokens; - - Ok(proto::SamplingParams { - temperature: request.temperature.unwrap_or(1.0), - top_p: request.top_p.unwrap_or(1.0), - top_k: request.top_k, - min_p: request.min_p, - frequency_penalty: request.frequency_penalty.unwrap_or(0.0), - presence_penalty: request.presence_penalty.unwrap_or(0.0), - repetition_penalty: request.repetition_penalty, - max_new_tokens, - stop: vec![], // Does not pass through request.stop yet (follow-up fix) - stop_token_ids: vec![], // Handled by Harmony stop tokens - skip_special_tokens: false, // Keep special tokens for Harmony - spaces_between_special_tokens: true, - ignore_eos: false, - no_stop_trim: false, - n: 1, // Responses API doesn't support n>1 - constraint: Self::build_constraint_for_responses(constraint)?, - ..Default::default() - }) - } - - /// Build constraint for Responses API - /// - /// Handles constraints from Harmony preparation stage (structural_tag for Harmony models, - /// structured output via text field, or tool call constraints). - /// - /// Note: Regular gRPC models use Chat API path with response_format, not this function. 
- fn build_constraint_for_responses( - constraint: Option<(String, String)>, - ) -> Result, String> { - if let Some((constraint_type, constraint_value)) = constraint { - let parsed_constraint = match constraint_type.as_str() { - "structural_tag" => { - // Harmony models: structural tag from preparation stage - proto::sampling_params::Constraint::StructuralTag(constraint_value) - } - "json_schema" => proto::sampling_params::Constraint::JsonSchema(constraint_value), - "ebnf" => proto::sampling_params::Constraint::EbnfGrammar(constraint_value), - "regex" => proto::sampling_params::Constraint::Regex(constraint_value), - _ => return Err(format!("Unknown constraint type: {constraint_type}")), - }; - Ok(Some(parsed_constraint)) - } else { - Ok(None) - } - } - /// Build a GenerateRequest from CreateMessageRequest (Anthropic Messages API) #[expect( clippy::unused_self, @@ -612,8 +446,10 @@ impl SglangSchedulerClient { multimodal_inputs: Option, tool_call_constraint: Option<(String, String)>, ) -> Result { - let sampling_params = - Self::build_grpc_sampling_params_from_messages(body, tool_call_constraint)?; + let sampling_params = crate::sampling_params::build_grpc_sampling_params_from_messages( + body, + tool_call_constraint, + )?; let grpc_request = proto::GenerateRequest { request_id, @@ -634,37 +470,6 @@ impl SglangSchedulerClient { Ok(grpc_request) } - /// Build gRPC SamplingParams from CreateMessageRequest - fn build_grpc_sampling_params_from_messages( - request: &CreateMessageRequest, - tool_call_constraint: Option<(String, String)>, - ) -> Result { - let stop_sequences = request.stop_sequences.clone().unwrap_or_default(); - - // Hardcode to true: gRPC backends return raw token IDs, not decoded text. - let skip_special_tokens = true; - - Ok(proto::SamplingParams { - temperature: request.temperature.unwrap_or(1.0) as f32, - top_p: request.top_p.unwrap_or(1.0) as f32, - top_k: request.top_k.map(|v| v as i32).unwrap_or(-1), - min_p: 0.0, - frequency_penalty: 0.0, - presence_penalty: 0.0, - repetition_penalty: 1.0, - max_new_tokens: Some(request.max_tokens), - stop: stop_sequences, - stop_token_ids: vec![], - skip_special_tokens, - spaces_between_special_tokens: true, - ignore_eos: false, - no_stop_trim: false, - n: 1, - constraint: Self::build_constraint_for_responses(tool_call_constraint)?, - ..Default::default() - }) - } - /// Build a GenerateRequest from CompletionRequest (`/v1/completions`) #[expect( clippy::unused_self, @@ -677,7 +482,8 @@ impl SglangSchedulerClient { original_text: String, token_ids: Vec, ) -> Result { - let sampling_params = Self::build_grpc_sampling_params_from_completion(body)?; + let sampling_params = + crate::sampling_params::build_grpc_sampling_params_from_completion(body)?; let grpc_request = proto::GenerateRequest { request_id, @@ -697,159 +503,6 @@ impl SglangSchedulerClient { Ok(grpc_request) } - - fn build_grpc_sampling_params_from_completion( - request: &CompletionRequest, - ) -> Result { - let stop_sequences = match &request.stop { - Some(StringOrArray::String(s)) => vec![s.clone()], - Some(StringOrArray::Array(arr)) => arr.clone(), - None => vec![], - }; - - let constraint = Self::build_single_constraint_from_completion(request)?; - - Ok(proto::SamplingParams { - temperature: request.temperature.unwrap_or(1.0), - top_p: request.top_p.unwrap_or(1.0), - top_k: request.top_k.unwrap_or(-1), - min_p: request.min_p.unwrap_or(0.0), - frequency_penalty: request.frequency_penalty.unwrap_or(0.0), - presence_penalty: request.presence_penalty.unwrap_or(0.0), - 
repetition_penalty: request.repetition_penalty.unwrap_or(1.0), - max_new_tokens: request.max_tokens, - min_new_tokens: request.min_tokens.unwrap_or(0), - stop: stop_sequences, - stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), - skip_special_tokens: request.skip_special_tokens, - spaces_between_special_tokens: true, - ignore_eos: request.ignore_eos, - no_stop_trim: request.no_stop_trim, - n: request.n.unwrap_or(1), - constraint, - ..Default::default() - }) - } - - fn build_single_constraint_from_completion( - request: &CompletionRequest, - ) -> Result, String> { - let mut constraints = Vec::new(); - if let Some(json_schema) = &request.json_schema { - constraints.push(proto::sampling_params::Constraint::JsonSchema( - json_schema.clone(), - )); - } - if let Some(regex) = &request.regex { - constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); - } - if let Some(ebnf) = &request.ebnf { - constraints.push(proto::sampling_params::Constraint::EbnfGrammar( - ebnf.clone(), - )); - } - - match constraints.len() { - 0 => Ok(None), - 1 => Ok(constraints.pop()), - _ => Err("Multiple structured constraints are not allowed".to_string()), - } - } - - fn build_single_constraint_from_plain( - params: &GenerateSamplingParams, - ) -> Result, String> { - let mut constraints = Vec::new(); - if let Some(json_schema) = ¶ms.json_schema { - constraints.push(proto::sampling_params::Constraint::JsonSchema( - json_schema.clone(), - )); - } - if let Some(regex) = ¶ms.regex { - constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); - } - if let Some(ebnf) = ¶ms.ebnf { - constraints.push(proto::sampling_params::Constraint::EbnfGrammar( - ebnf.clone(), - )); - } - - match constraints.len() { - 0 => Ok(None), - 1 => Ok(constraints.pop()), - _ => Err("Multiple structured constraints are not allowed".to_string()), - } - } - - fn build_sampling_params_from_plain( - params: Option<&GenerateSamplingParams>, - ) -> Result { - let mut sampling = proto::SamplingParams { - temperature: 1.0, - top_p: 1.0, - top_k: -1, - repetition_penalty: 1.0, - n: 1, - skip_special_tokens: true, - spaces_between_special_tokens: true, - ..Default::default() - }; - - let Some(p) = params else { - return Ok(sampling); - }; - - // Simple field mappings using a macro - macro_rules! 
map_field { - ($field:ident) => { - if let Some(val) = p.$field { - sampling.$field = val; - } - }; - } - - map_field!(temperature); - map_field!(top_p); - map_field!(top_k); - map_field!(frequency_penalty); - map_field!(presence_penalty); - map_field!(repetition_penalty); - map_field!(min_p); - map_field!(ignore_eos); - map_field!(skip_special_tokens); - map_field!(no_stop_trim); - - // Handle stop sequences - if let Some(stop) = &p.stop { - match stop { - StringOrArray::String(s) => sampling.stop.push(s.clone()), - StringOrArray::Array(arr) => sampling.stop.extend(arr.clone()), - } - } - - // Handle stop token IDs - if let Some(stop_token_ids) = &p.stop_token_ids { - sampling.stop_token_ids.clone_from(stop_token_ids); - } - - // Handle max_new_tokens - sampling.max_new_tokens = p.max_new_tokens; - - // Handle min_new_tokens - if let Some(min_new_tokens) = p.min_new_tokens { - sampling.min_new_tokens = min_new_tokens; - } - - // Handle n - if let Some(n) = p.n { - sampling.n = n; - } - - // Handle constraints (exactly one allowed) - sampling.constraint = Self::build_single_constraint_from_plain(p)?; - - Ok(sampling) - } } // --------------------------------------------------------------------------- @@ -1008,7 +661,7 @@ mod tests { }; let params = - SglangSchedulerClient::build_grpc_sampling_params_from_responses(&request, None) + crate::sampling_params::build_grpc_sampling_params_from_responses(&request, None) .expect("build sampling params"); assert_eq!(params.top_k, 40); @@ -1023,7 +676,7 @@ mod tests { ..Default::default() }; let disabled_params = - SglangSchedulerClient::build_grpc_sampling_params_from_responses(&disabled, None) + crate::sampling_params::build_grpc_sampling_params_from_responses(&disabled, None) .expect("build sampling params"); assert_eq!(disabled_params.top_k, -1); } diff --git a/crates/grpc_client/src/tokenspeed_scheduler.rs b/crates/grpc_client/src/tokenspeed_scheduler.rs new file mode 100644 index 000000000..f82ec5974 --- /dev/null +++ b/crates/grpc_client/src/tokenspeed_scheduler.rs @@ -0,0 +1,546 @@ +//! gRPC client for the TokenSpeed scheduler service. +//! +//! Wire types are TokenSpeed-native end-to-end (`tokenspeed_proto::*`). +//! Sampling params come from the shared `crate::sampling_params` helpers +//! and are field-mapped to TokenSpeed's shape via [`translate::sampling_params`]. +//! The unary RPC adapters (`translate::model_info` / `server_info` / `loads`) +//! still emit the legacy router-side shape pending dedicated +//! `ModelInfo::TokenSpeed` / `ServerInfo::TokenSpeed` arms in the router. + +use std::{ + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + task::{Context, Poll}, + time::Duration, +}; + +use openai_protocol::{ + chat::ChatCompletionRequest, completion::CompletionRequest, generate::GenerateRequest, + messages::CreateMessageRequest, responses::ResponsesRequest, +}; +use tonic::{transport::Channel, Request, Streaming}; +use tracing::{debug, warn}; + +use crate::{sglang_scheduler::proto as sglang, BoxedTraceInjector, NoopTraceInjector}; + +#[expect(clippy::allow_attributes)] +pub mod tokenspeed_proto { + #![allow(clippy::all, clippy::absolute_paths, unused_qualifications)] + tonic::include_proto!("tokenspeed.grpc.scheduler"); +} + +/// Fire-and-forget abort sender invoked from `Drop`. +type AbortDispatcher = Arc; + +/// Auto-aborting wrapper around the TokenSpeed generate stream. +/// Sends Abort on Drop unless `mark_completed` ran first. 
+pub struct AbortOnDropStream { + inner: Streaming, + request_id: String, + abort_dispatcher: AbortDispatcher, + aborted: Arc, +} + +impl AbortOnDropStream { + pub fn new( + stream: Streaming, + request_id: String, + abort_dispatcher: AbortDispatcher, + ) -> Self { + debug!( + "Created TokenSpeed AbortOnDropStream for request {}", + request_id + ); + Self { + inner: stream, + request_id, + abort_dispatcher, + aborted: Arc::new(AtomicBool::new(false)), + } + } + + pub fn mark_completed(&self) { + self.aborted.store(true, Ordering::Release); + debug!("TokenSpeed request {} marked as completed", self.request_id); + } +} + +impl Drop for AbortOnDropStream { + fn drop(&mut self) { + if self + .aborted + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) + .is_err() + { + return; + } + debug!( + "TokenSpeed stream dropped without completion for request {}, sending abort", + self.request_id + ); + (self.abort_dispatcher)(self.request_id.clone()); + } +} + +impl futures::Stream for AbortOnDropStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_next(cx) + } +} + +/// gRPC client for the TokenSpeed scheduler. +#[derive(Clone)] +pub struct TokenSpeedSchedulerClient { + client: tokenspeed_proto::token_speed_scheduler_client::TokenSpeedSchedulerClient, + trace_injector: BoxedTraceInjector, +} + +impl TokenSpeedSchedulerClient { + pub async fn connect(endpoint: &str) -> Result> { + Self::connect_with_trace_injector(endpoint, Arc::new(NoopTraceInjector)).await + } + + pub async fn connect_with_trace_injector( + endpoint: &str, + trace_injector: BoxedTraceInjector, + ) -> Result> { + debug!("Connecting to TokenSpeed scheduler at {}", endpoint); + + let http_endpoint = if let Some(addr) = endpoint.strip_prefix("grpc://") { + format!("http://{addr}") + } else { + endpoint.to_string() + }; + + // Channel knobs match the other engine clients. + let channel = Channel::from_shared(http_endpoint)? + .http2_keep_alive_interval(Duration::from_secs(30)) + .keep_alive_timeout(Duration::from_secs(10)) + .keep_alive_while_idle(true) + .tcp_keepalive(Some(Duration::from_secs(60))) + .tcp_nodelay(true) + .http2_adaptive_window(true) + .initial_stream_window_size(Some(16 * 1024 * 1024)) + .initial_connection_window_size(Some(32 * 1024 * 1024)) + .connect() + .await?; + + let client = + tokenspeed_proto::token_speed_scheduler_client::TokenSpeedSchedulerClient::new(channel); + + Ok(Self { + client, + trace_injector, + }) + } + + #[must_use] + pub fn with_trace_injector(mut self, trace_injector: BoxedTraceInjector) -> Self { + self.trace_injector = trace_injector; + self + } + + /// Submit a generation request. 
+ pub async fn generate( + &self, + req: tokenspeed_proto::GenerateRequest, + ) -> Result { + let request_id = req.request_id.clone(); + + let mut client = self.client.clone(); + let mut request = Request::new(req); + + if let Err(e) = self.trace_injector.inject(request.metadata_mut()) { + warn!("Failed to inject trace context: {}", e); + } + + let response = client.generate(request).await?; + + Ok(AbortOnDropStream::new( + response.into_inner(), + request_id, + tokenspeed_abort_dispatcher(self.clone()), + )) + } + + pub async fn health_check(&self) -> Result { + debug!("Sending TokenSpeed health check request"); + let request = Request::new(tokenspeed_proto::HealthCheckRequest {}); + let mut client = self.client.clone(); + let response = client.health_check(request).await?; + let r = response.into_inner(); + Ok(sglang::HealthCheckResponse { + healthy: r.healthy, + message: r.message, + }) + } + + pub async fn abort_request( + &self, + request_id: String, + reason: String, + ) -> Result<(), tonic::Status> { + debug!( + "Sending TokenSpeed abort for {} (reason: {})", + request_id, reason + ); + let request = Request::new(tokenspeed_proto::AbortRequest { + request_id: request_id.clone(), + reason, + }); + let mut client = self.client.clone(); + let response = client.abort(request).await?; + debug!( + "TokenSpeed abort response for {}: success={}, message={}", + request_id, + response.get_ref().success, + response.get_ref().message + ); + Ok(()) + } + + pub async fn get_model_info(&self) -> Result { + let request = Request::new(tokenspeed_proto::GetModelInfoRequest {}); + let mut client = self.client.clone(); + let response = client.get_model_info(request).await?; + Ok(translate::model_info(response.into_inner())) + } + + pub async fn get_server_info(&self) -> Result { + let request = Request::new(tokenspeed_proto::GetServerInfoRequest {}); + let mut client = self.client.clone(); + let response = client.get_server_info(request).await?; + Ok(translate::server_info(response.into_inner())) + } + + pub async fn get_loads( + &self, + include: Vec, + ) -> Result { + let request = Request::new(tokenspeed_proto::GetLoadsRequest { + dp_rank: None, + include, + }); + let mut client = self.client.clone(); + let response = client.get_loads(request).await?; + Ok(translate::loads(response.into_inner())) + } + + // ── Request builders ────────────────────────────────────────────── + + #[expect( + clippy::unused_self, + reason = "receiver kept for API parity with the other engine clients" + )] + pub fn build_generate_request_from_chat( + &self, + request_id: String, + body: &ChatCompletionRequest, + processed_text: String, + token_ids: Vec, + tool_call_constraint: Option<(String, String)>, + ) -> Result { + let sglang_sampling = crate::sampling_params::build_grpc_sampling_params_from_chat( + body, + tool_call_constraint, + )?; + Ok(tokenspeed_proto::GenerateRequest { + request_id, + tokenized: Some(tokenspeed_proto::TokenizedInput { + original_text: processed_text, + input_ids: token_ids, + }), + sampling_params: Some(translate::sampling_params(sglang_sampling)), + return_logprob: body.logprobs, + logprob_start_len: Some(-1), + top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32, + stream: body.stream, + ..Default::default() + }) + } + + #[expect( + clippy::unused_self, + reason = "receiver kept for API parity with the other engine clients" + )] + pub fn build_plain_generate_request( + &self, + request_id: String, + body: &GenerateRequest, + original_text: Option, + token_ids: Vec, + ) -> Result { + let 
sglang_sampling = crate::sampling_params::build_sampling_params_from_plain( + body.sampling_params.as_ref(), + )?; + Ok(tokenspeed_proto::GenerateRequest { + request_id, + tokenized: Some(tokenspeed_proto::TokenizedInput { + original_text: original_text.unwrap_or_default(), + input_ids: token_ids, + }), + sampling_params: Some(translate::sampling_params(sglang_sampling)), + return_logprob: body.return_logprob.unwrap_or(false), + logprob_start_len: Some(body.logprob_start_len.unwrap_or(-1)), + top_logprobs_num: body.top_logprobs_num.unwrap_or(0), + token_ids_logprob: body.token_ids_logprob.clone().unwrap_or_default(), + stream: body.stream, + }) + } + + #[expect( + clippy::unused_self, + reason = "receiver kept for API parity with the other engine clients" + )] + pub fn build_generate_request_from_responses( + &self, + request_id: String, + body: &ResponsesRequest, + processed_text: String, + token_ids: Vec, + constraint: Option<(String, String)>, + ) -> Result { + let sglang_sampling = + crate::sampling_params::build_grpc_sampling_params_from_responses(body, constraint)?; + Ok(tokenspeed_proto::GenerateRequest { + request_id, + tokenized: Some(tokenspeed_proto::TokenizedInput { + original_text: processed_text, + input_ids: token_ids, + }), + sampling_params: Some(translate::sampling_params(sglang_sampling)), + stream: body.stream.unwrap_or(false), + ..Default::default() + }) + } + + #[expect( + clippy::unused_self, + reason = "receiver kept for API parity with the other engine clients" + )] + pub fn build_generate_request_from_messages( + &self, + request_id: String, + body: &CreateMessageRequest, + processed_text: String, + token_ids: Vec, + tool_call_constraint: Option<(String, String)>, + ) -> Result { + let sglang_sampling = crate::sampling_params::build_grpc_sampling_params_from_messages( + body, + tool_call_constraint, + )?; + Ok(tokenspeed_proto::GenerateRequest { + request_id, + tokenized: Some(tokenspeed_proto::TokenizedInput { + original_text: processed_text, + input_ids: token_ids, + }), + sampling_params: Some(translate::sampling_params(sglang_sampling)), + stream: body.stream.unwrap_or(false), + ..Default::default() + }) + } + + #[expect( + clippy::unused_self, + reason = "receiver kept for API parity with the other engine clients" + )] + pub fn build_generate_request_from_completion( + &self, + request_id: String, + body: &CompletionRequest, + original_text: String, + token_ids: Vec, + ) -> Result { + let sglang_sampling = + crate::sampling_params::build_grpc_sampling_params_from_completion(body)?; + Ok(tokenspeed_proto::GenerateRequest { + request_id, + tokenized: Some(tokenspeed_proto::TokenizedInput { + original_text, + input_ids: token_ids, + }), + sampling_params: Some(translate::sampling_params(sglang_sampling)), + return_logprob: body.logprobs.is_some(), + logprob_start_len: Some(-1), + top_logprobs_num: body.logprobs.unwrap_or(0) as i32, + stream: body.stream, + ..Default::default() + }) + } +} + +/// Spawn a fire-and-forget abort RPC against TokenSpeed when an +/// ``AbortOnDropStream`` is dropped without completion. 
+fn tokenspeed_abort_dispatcher(client: TokenSpeedSchedulerClient) -> AbortDispatcher { + Arc::new(move |request_id: String| { + let client = client.clone(); + let request_id_for_log = request_id.clone(); + #[expect( + clippy::disallowed_methods, + reason = "fire-and-forget abort on Drop is intentional" + )] + tokio::spawn(async move { + if let Err(e) = client + .abort_request(request_id, "Stream dropped".to_string()) + .await + { + warn!( + "Failed to send TokenSpeed abort on drop for request {}: {}", + request_id_for_log, e + ); + } + }); + }) +} + +// Sampling-params + unary RPC adapters: map between TokenSpeed's wire +// shape and the router-side shape the metadata wrappers consume. +mod translate { + use super::{sglang, tokenspeed_proto as ts}; + + pub(super) fn sampling_params(s: sglang::SamplingParams) -> ts::SamplingParams { + // Source scalars are non-optional with semantic defaults already + // applied; wrap in `Some(_)` so the servicer's `HasField()` checks + // distinguish "set" from "unset" on the wire. + ts::SamplingParams { + temperature: Some(s.temperature), + top_p: Some(s.top_p), + top_k: Some(s.top_k), + min_p: Some(s.min_p), + frequency_penalty: Some(s.frequency_penalty), + presence_penalty: Some(s.presence_penalty), + repetition_penalty: Some(s.repetition_penalty), + max_new_tokens: s.max_new_tokens, + min_new_tokens: s.min_new_tokens, + stop: s.stop, + stop_token_ids: s.stop_token_ids, + ignore_eos: s.ignore_eos, + skip_special_tokens: s.skip_special_tokens, + spaces_between_special_tokens: s.spaces_between_special_tokens, + no_stop_trim: s.no_stop_trim, + n: s.n, + logit_bias: s.logit_bias, + constraint: s.constraint.map(constraint), + custom_params: s.custom_params, + } + } + + fn constraint(c: sglang::sampling_params::Constraint) -> ts::sampling_params::Constraint { + match c { + sglang::sampling_params::Constraint::Regex(r) => { + ts::sampling_params::Constraint::Regex(r) + } + sglang::sampling_params::Constraint::JsonSchema(s) => { + ts::sampling_params::Constraint::JsonSchema(s) + } + sglang::sampling_params::Constraint::EbnfGrammar(g) => { + ts::sampling_params::Constraint::EbnfGrammar(g) + } + sglang::sampling_params::Constraint::StructuralTag(t) => { + ts::sampling_params::Constraint::StructuralTag(t) + } + } + } + + pub(super) fn model_info(r: ts::GetModelInfoResponse) -> sglang::GetModelInfoResponse { + // Surface TokenSpeed's `preferred_sampling_params` JSON in both label + // fields the discovery path may consult, so worker-discovery can + // expose model-published defaults (`temperature`, `top_p`, etc.) to + // the router's default-injection stage. 
+ let preferred = r.preferred_sampling_params; + sglang::GetModelInfoResponse { + model_path: r.model_path, + tokenizer_path: r.tokenizer_path, + is_generation: true, + preferred_sampling_params: preferred.clone(), + weight_version: r.weight_version, + served_model_name: r.served_model_name, + max_context_length: r.max_context_length, + vocab_size: r.vocab_size, + supports_vision: false, + model_type: r.model_type, + eos_token_ids: r.eos_token_ids, + pad_token_id: r.pad_token_id, + bos_token_id: r.bos_token_id, + max_req_input_len: r.max_req_input_len, + architectures: r.architectures, + id2label_json: String::new(), + num_labels: 0, + default_sampling_params_json: preferred, + } + } + + pub(super) fn server_info(r: ts::GetServerInfoResponse) -> sglang::GetServerInfoResponse { + sglang::GetServerInfoResponse { + server_args: r.server_args, + scheduler_info: r.scheduler_info, + active_requests: r.active_requests, + is_paused: r.is_paused, + last_receive_timestamp: 0.0, + uptime_seconds: r.uptime_seconds, + sglang_version: r.tokenspeed_version, + server_type: "grpc".to_string(), + start_time: r.start_time, + max_total_num_tokens: r.max_total_num_tokens, + } + } + + pub(super) fn loads(r: ts::GetLoadsResponse) -> sglang::GetLoadsResponse { + sglang::GetLoadsResponse { + timestamp: r.timestamp, + version: r.version, + dp_rank_count: r.dp_rank_count, + loads: r.loads.into_iter().map(scheduler_load).collect(), + aggregate: r.aggregate.map(aggregate_metrics), + } + } + + fn scheduler_load(s: ts::SchedulerLoad) -> sglang::SchedulerLoad { + sglang::SchedulerLoad { + dp_rank: s.dp_rank, + num_running_reqs: s.num_running_reqs, + num_waiting_reqs: s.num_waiting_reqs, + num_total_reqs: s.num_total_reqs, + num_used_tokens: s.num_used_tokens, + max_total_num_tokens: s.max_total_num_tokens, + token_usage: s.token_usage, + gen_throughput: s.gen_throughput, + cache_hit_rate: s.cache_hit_rate, + utilization: s.utilization, + max_running_requests: s.max_running_requests, + memory: s.memory.map(|m| sglang::MemoryMetrics { + weight_gb: m.weight_gb, + kv_cache_gb: m.kv_cache_gb, + graph_gb: m.graph_gb, + token_capacity: m.token_capacity, + }), + speculative: None, + lora: None, + disaggregation: None, + queues: s.queues.map(|q| sglang::QueueMetrics { + waiting: q.waiting, + grammar: q.grammar, + paused: q.paused, + retracted: q.retracted, + }), + } + } + + fn aggregate_metrics(a: ts::AggregateMetrics) -> sglang::AggregateMetrics { + sglang::AggregateMetrics { + total_running_reqs: a.total_running_reqs, + total_waiting_reqs: a.total_waiting_reqs, + total_reqs: a.total_reqs, + avg_token_usage: a.avg_token_usage, + avg_throughput: a.avg_throughput, + avg_utilization: a.avg_utilization, + } + } +} diff --git a/crates/protocols/src/worker.rs b/crates/protocols/src/worker.rs index 03d6d7547..cd7897eb4 100644 --- a/crates/protocols/src/worker.rs +++ b/crates/protocols/src/worker.rs @@ -197,6 +197,8 @@ pub enum RuntimeType { Trtllm, /// MLX runtime (Apple Silicon). Mlx, + /// TokenSpeed runtime. + TokenSpeed, /// External OpenAI-compatible API (not local inference). 
External, } @@ -216,6 +218,7 @@ impl std::fmt::Display for RuntimeType { RuntimeType::Vllm => write!(f, "vllm"), RuntimeType::Trtllm => write!(f, "trtllm"), RuntimeType::Mlx => write!(f, "mlx"), + RuntimeType::TokenSpeed => write!(f, "tokenspeed"), RuntimeType::External => write!(f, "external"), } } @@ -235,6 +238,8 @@ impl std::str::FromStr for RuntimeType { Ok(RuntimeType::Trtllm) } else if s.eq_ignore_ascii_case("mlx") { Ok(RuntimeType::Mlx) + } else if s.eq_ignore_ascii_case("tokenspeed") { + Ok(RuntimeType::TokenSpeed) } else if s.eq_ignore_ascii_case("external") { Ok(RuntimeType::External) } else { diff --git a/crates/reasoning_parser/src/factory.rs b/crates/reasoning_parser/src/factory.rs index eeda6cbd1..4eb9e2c18 100644 --- a/crates/reasoning_parser/src/factory.rs +++ b/crates/reasoning_parser/src/factory.rs @@ -9,7 +9,7 @@ use tokio::sync::Mutex; use crate::{ parsers::{ BaseReasoningParser, CohereCmdParser, DeepSeekR1Parser, Glm45Parser, KimiParser, - MiniMaxParser, NanoV3Parser, Qwen3Parser, QwenThinkingParser, Step3Parser, + MiniMaxParser, NanoV3Parser, NoneParser, Qwen3Parser, QwenThinkingParser, Step3Parser, }, traits::{ParserConfig, ReasoningParser, DEFAULT_MAX_BUFFER_SIZE}, }; @@ -176,6 +176,11 @@ impl ParserFactory { Box::new(BaseReasoningParser::new(ParserConfig::default())) }); + // Register no-op parser: returns all text as normal content, + // never produces reasoning_content. Selectable via + // `--reasoning-parser none`. + registry.register_parser("none", || Box::new(NoneParser::new())); + // Register DeepSeek-R1 parser (starts with in_reasoning=true) registry.register_parser("deepseek_r1", || Box::new(DeepSeekR1Parser::new())); diff --git a/crates/reasoning_parser/src/lib.rs b/crates/reasoning_parser/src/lib.rs index 83d49fbd4..97d51d998 100644 --- a/crates/reasoning_parser/src/lib.rs +++ b/crates/reasoning_parser/src/lib.rs @@ -5,7 +5,7 @@ pub mod traits; pub use factory::{ParserFactory, ParserRegistry, PooledParser}; pub use parsers::{ BaseReasoningParser, CohereCmdParser, DeepSeekR1Parser, Glm45Parser, KimiParser, MiniMaxParser, - NanoV3Parser, Qwen3Parser, QwenThinkingParser, Step3Parser, + NanoV3Parser, NoneParser, Qwen3Parser, QwenThinkingParser, Step3Parser, }; pub use traits::{ ParseError, ParserConfig, ParserResult, ReasoningParser, DEFAULT_MAX_BUFFER_SIZE, diff --git a/crates/reasoning_parser/src/parsers/mod.rs b/crates/reasoning_parser/src/parsers/mod.rs index 81757d04c..baa7b795f 100644 --- a/crates/reasoning_parser/src/parsers/mod.rs +++ b/crates/reasoning_parser/src/parsers/mod.rs @@ -5,6 +5,7 @@ pub mod glm45; pub mod kimi; pub mod minimax; pub mod nano_v3; +pub mod none; pub mod qwen3; pub mod step3; @@ -15,5 +16,6 @@ pub use glm45::Glm45Parser; pub use kimi::KimiParser; pub use minimax::MiniMaxParser; pub use nano_v3::NanoV3Parser; +pub use none::NoneParser; pub use qwen3::{Qwen3Parser, QwenThinkingParser}; pub use step3::Step3Parser; diff --git a/crates/reasoning_parser/src/parsers/none.rs b/crates/reasoning_parser/src/parsers/none.rs new file mode 100644 index 000000000..82e9e4363 --- /dev/null +++ b/crates/reasoning_parser/src/parsers/none.rs @@ -0,0 +1,123 @@ +// No-op reasoning parser. +// +// Returns all input text as `normal_text` and never produces reasoning text, +// regardless of whether the input contains `<think>`/`</think>` (or any other) +// markers. Use this when the model emits a single content stream and the +// caller does not want any portion of it separated into `reasoning_content`.
+ +use crate::traits::{ParseError, ParserResult, ReasoningParser}; + +/// Parser that performs no reasoning extraction. +/// +/// Every byte received is forwarded to `normal_text`; `reasoning_text` is always +/// empty. State is trivial: no buffering, no tokens, no flags. +#[derive(Debug, Clone, Default)] +pub struct NoneParser; + +impl NoneParser { + pub fn new() -> Self { + Self + } +} + +impl ReasoningParser for NoneParser { + fn detect_and_parse_reasoning(&mut self, text: &str) -> Result { + Ok(ParserResult::normal(text.to_string())) + } + + fn parse_reasoning_streaming_incremental( + &mut self, + text: &str, + ) -> Result { + Ok(ParserResult::normal(text.to_string())) + } + + fn reset(&mut self) {} + + fn model_type(&self) -> &str { + "none" + } + + fn is_in_reasoning(&self) -> bool { + false + } + + fn mark_reasoning_started(&mut self) {} + + fn mark_think_start_stripped(&mut self) {} +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn plain_text_goes_to_normal() { + let mut parser = NoneParser::new(); + let result = parser + .detect_and_parse_reasoning("just some content") + .unwrap(); + assert_eq!(result.normal_text, "just some content"); + assert_eq!(result.reasoning_text, ""); + } + + #[test] + fn think_tags_are_kept_in_normal_text() { + let mut parser = NoneParser::new(); + let result = parser + .detect_and_parse_reasoning("cotanswer") + .unwrap(); + assert_eq!(result.normal_text, "cotanswer"); + assert_eq!(result.reasoning_text, ""); + } + + #[test] + fn streaming_passes_chunks_through_unchanged() { + let mut parser = NoneParser::new(); + + let r1 = parser + .parse_reasoning_streaming_incremental("") + .unwrap(); + assert_eq!(r1.normal_text, ""); + assert_eq!(r1.reasoning_text, ""); + + let r2 = parser + .parse_reasoning_streaming_incremental("hidden cot") + .unwrap(); + assert_eq!(r2.normal_text, "hidden cot"); + assert_eq!(r2.reasoning_text, ""); + + let r3 = parser + .parse_reasoning_streaming_incremental("visible") + .unwrap(); + assert_eq!(r3.normal_text, "visible"); + assert_eq!(r3.reasoning_text, ""); + } + + #[test] + fn empty_input_is_normal_and_empty() { + let mut parser = NoneParser::new(); + let result = parser.detect_and_parse_reasoning("").unwrap(); + assert!(result.is_empty()); + } + + #[test] + fn mark_helpers_do_not_change_behavior() { + let mut parser = NoneParser::new(); + parser.mark_reasoning_started(); + parser.mark_think_start_stripped(); + assert!(!parser.is_in_reasoning()); + + let result = parser + .detect_and_parse_reasoning("xy") + .unwrap(); + assert_eq!(result.normal_text, "xy"); + assert_eq!(result.reasoning_text, ""); + } + + #[test] + fn model_type_is_none() { + let parser = NoneParser::new(); + assert_eq!(parser.model_type(), "none"); + } +} diff --git a/e2e_test/chat_completions/test_enable_thinking.py b/e2e_test/chat_completions/test_enable_thinking.py index c555ac8e8..a896bd54a 100644 --- a/e2e_test/chat_completions/test_enable_thinking.py +++ b/e2e_test/chat_completions/test_enable_thinking.py @@ -18,13 +18,13 @@ # ============================================================================= -# Enable Thinking Tests (Qwen 30B) +# Enable Thinking Tests (Qwen3) # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) -@pytest.mark.model("Qwen/Qwen3-30B-A3B") +@pytest.mark.model("Qwen/Qwen3-4B") @pytest.mark.gateway(extra_args=["--reasoning-parser", "qwen3", "--history-backend", 
"memory"]) @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) @pytest.mark.parametrize("api_client", ["openai", "smg"], indirect=True) diff --git a/e2e_test/chat_completions/test_function_calling.py b/e2e_test/chat_completions/test_function_calling.py index b4df06e4c..83cd1b1c6 100644 --- a/e2e_test/chat_completions/test_function_calling.py +++ b/e2e_test/chat_completions/test_function_calling.py @@ -22,7 +22,9 @@ # Shared Tool Definitions # ============================================================================= -# System message for Llama3.2 function calling +# System message for Llama3.2 function calling — prescribes the +# {"name": ..., "parameters": ...} JSON shape that the ``llama`` tool +# parser looks for. Used by ``TestToolChoiceLlama`` below. LLAMA_SYSTEM_MESSAGE = ( "You are a helpful assistant with tool calling capabilities. " "Only reply with a tool call if the function exists in the library provided by the user. " @@ -100,14 +102,14 @@ # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.2-1B-Instruct") @pytest.mark.gateway(extra_args=["--tool-call-parser", "llama", "--history-backend", "memory"]) @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) @pytest.mark.parametrize("api_client", ["openai", "smg"], indirect=True) class TestOpenAIServerFunctionCalling: - """Tests for OpenAI-compatible function calling with Llama tool parser.""" + """Tests for OpenAI-compatible function calling with the llama tool parser.""" def test_function_calling_format(self, model, api_client): """Test: Whether the function call format returned by the AI is correct. @@ -265,8 +267,8 @@ def test_function_calling_streaming_args_parsing(self, model, api_client): }, "required": ["a", "b"], }, - # Llama-3.2-1B is flaky in tool call. It won't always respond with - # parameters unless we set strict. + # Llama-3.2-1B is flaky in tool call format, so we force it + # with strict mode. "strict": True, }, } @@ -377,7 +379,6 @@ def test_function_call_required(self, model, api_client): - When tool_choice == "required", the model should return one or more tool_calls. """ - tools = [ { "type": "function", @@ -457,7 +458,6 @@ def test_function_call_specific(self, model, api_client): - When tool_choice is a specific ToolChoice, the model should return one or more tool_calls. """ - tools = [ { "type": "function", @@ -526,7 +526,6 @@ def test_streaming_multiple_choices_finish_reason(self, model, api_client): This tests the fix for the bug where only the last index got a finish_reason chunk. 
""" - tools = [ { "type": "function", @@ -709,7 +708,7 @@ def test_streaming_multiple_choices_without_tools(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.1-8B-Instruct") @pytest.mark.gateway(extra_args=["--tool-call-parser", "pythonic", "--history-backend", "memory"]) @@ -1489,7 +1488,7 @@ def test_conflicting_defs_required_tool_choice(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.2-1B-Instruct") @pytest.mark.gateway(extra_args=["--tool-call-parser", "llama", "--history-backend", "memory"]) @@ -1510,9 +1509,9 @@ class TestToolChoiceLlama(_TestToolChoiceBase): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) -@pytest.mark.model("Qwen/Qwen2.5-7B-Instruct") +@pytest.mark.model("Qwen/Qwen3-4B-Instruct-2507") @pytest.mark.gateway(extra_args=["--tool-call-parser", "qwen", "--history-backend", "memory"]) @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) @pytest.mark.parametrize("api_client", ["openai", "smg"], indirect=True) @@ -1579,9 +1578,9 @@ def test_conflicting_defs_required_tool_choice(self, model, api_client): } -@pytest.mark.engine("sglang", "vllm", "trtllm") -@pytest.mark.gpu(2) -@pytest.mark.model("Qwen/Qwen2.5-14B-Instruct") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") +@pytest.mark.gpu(1) +@pytest.mark.model("Qwen/Qwen3-4B-Instruct-2507") @pytest.mark.gateway(extra_args=["--tool-call-parser", "qwen", "--history-backend", "memory"]) @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) @pytest.mark.parametrize("api_client", ["openai", "smg"], indirect=True) diff --git a/e2e_test/chat_completions/test_openai_server.py b/e2e_test/chat_completions/test_openai_server.py index 517cdbb4a..d05d73287 100644 --- a/e2e_test/chat_completions/test_openai_server.py +++ b/e2e_test/chat_completions/test_openai_server.py @@ -20,7 +20,7 @@ # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.1-8B-Instruct") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) @@ -33,7 +33,23 @@ class TestChatCompletion: # Harmony (gpt-oss) does not trim because its detokenization is not channel-aware. 
STOP_SEQUENCE_TRIMMED = True - @pytest.mark.parametrize("logprobs", [None, 5]) + @pytest.mark.parametrize( + "logprobs", + [ + None, + pytest.param( + 5, + marks=pytest.mark.skip_for_runtime( + "tokenspeed", + reason=( + "tokenspeed's --enable-top-logprobs is not yet implemented " + "(raises at startup); base output logprobs work via " + "--enable-output-logprobs but the test requires top_logprobs=5" + ), + ), + ), + ], + ) @pytest.mark.parametrize("parallel_sample_num", [1, 2]) def test_chat_completion(self, model, api_client, logprobs, parallel_sample_num): """Test non-streaming chat completion with logprobs and parallel sampling.""" @@ -73,7 +89,23 @@ def test_chat_completion(self, model, api_client, logprobs, parallel_sample_num) assert response.usage.completion_tokens > 0 assert response.usage.total_tokens > 0 - @pytest.mark.parametrize("logprobs", [None, 5]) + @pytest.mark.parametrize( + "logprobs", + [ + None, + pytest.param( + 5, + marks=pytest.mark.skip_for_runtime( + "tokenspeed", + reason=( + "tokenspeed's --enable-top-logprobs is not yet implemented " + "(raises at startup); base output logprobs work via " + "--enable-output-logprobs but the test requires top_logprobs=5" + ), + ), + ), + ], + ) @pytest.mark.parametrize("parallel_sample_num", [1, 2]) def test_chat_completion_stream(self, model, api_client, logprobs, parallel_sample_num): """Test streaming chat completion with logprobs and parallel sampling.""" @@ -359,7 +391,7 @@ def _delta_text(delta): return delta.content or getattr(delta, "reasoning_content", "") or "" -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.model("openai/gpt-oss-20b") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) @@ -375,13 +407,45 @@ class TestChatCompletionGptOss(TestChatCompletion): STOP_SEQUENCE_TRIMMED = False - @pytest.mark.parametrize("logprobs", [None, 5]) + @pytest.mark.parametrize( + "logprobs", + [ + None, + pytest.param( + 5, + marks=pytest.mark.skip_for_runtime( + "tokenspeed", + reason=( + "tokenspeed's --enable-top-logprobs is not yet implemented " + "(raises at startup); base output logprobs work via " + "--enable-output-logprobs but the test requires top_logprobs=5" + ), + ), + ), + ], + ) @pytest.mark.parametrize("parallel_sample_num", [1, 2]) def test_chat_completion(self, model, api_client, logprobs, parallel_sample_num): """Test non-streaming chat completion with logprobs and parallel sampling.""" super().test_chat_completion(model, api_client, logprobs, parallel_sample_num) - @pytest.mark.parametrize("logprobs", [None, 5]) + @pytest.mark.parametrize( + "logprobs", + [ + None, + pytest.param( + 5, + marks=pytest.mark.skip_for_runtime( + "tokenspeed", + reason=( + "tokenspeed's --enable-top-logprobs is not yet implemented " + "(raises at startup); base output logprobs work via " + "--enable-output-logprobs but the test requires top_logprobs=5" + ), + ), + ), + ], + ) @pytest.mark.parametrize("parallel_sample_num", [1, 2]) @pytest.mark.skip_for_runtime( "trtllm", reason="trtllm may return more top_logprobs than requested in streaming" @@ -402,7 +466,7 @@ def test_response_prefill(self, model, api_client): pass -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(4) @pytest.mark.model("openai/gpt-oss-120b") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) diff --git a/e2e_test/chat_completions/test_structured_output.py 
b/e2e_test/chat_completions/test_structured_output.py index 654443bbd..2c8fe7235 100644 --- a/e2e_test/chat_completions/test_structured_output.py +++ b/e2e_test/chat_completions/test_structured_output.py @@ -124,7 +124,7 @@ def test_response_format_json_schema_stream(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.1-8B-Instruct") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) @@ -139,7 +139,7 @@ class TestStructuredOutputRegular(_TestStructuredOutputBase): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.model("openai/gpt-oss-20b") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) @@ -149,7 +149,7 @@ class TestStructuredOutputGptOss(_TestStructuredOutputBase): """Structured output tests for Harmony models (GPT-OSS 20B, 1 GPU).""" -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(4) @pytest.mark.model("openai/gpt-oss-120b") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) diff --git a/e2e_test/chat_completions/test_validation.py b/e2e_test/chat_completions/test_validation.py index 7192f8e08..96e70f371 100644 --- a/e2e_test/chat_completions/test_validation.py +++ b/e2e_test/chat_completions/test_validation.py @@ -37,7 +37,7 @@ def get_tokenizer(model_path: str): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.1-8B-Instruct") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) @@ -107,7 +107,7 @@ def test_ignore_eos(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.1-8B-Instruct") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) @@ -167,7 +167,7 @@ def run_chat_completion(): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.model("openai/gpt-oss-20b") @pytest.mark.gateway(extra_args=["--history-backend", "memory"]) @@ -243,7 +243,7 @@ def test_tool_choice_with_response_format_rejected(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm") +@pytest.mark.engine("sglang", "vllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.2-1B-Instruct") @pytest.mark.gateway(extra_args=["--tool-call-parser", "llama", "--history-backend", "memory"]) diff --git a/e2e_test/completions/test_basic.py b/e2e_test/completions/test_basic.py index e91d3f90f..865bbcf87 100644 --- a/e2e_test/completions/test_basic.py +++ b/e2e_test/completions/test_basic.py @@ -9,7 +9,7 @@ import pytest -@pytest.mark.engine("sglang", "vllm") +@pytest.mark.engine("sglang", 
"vllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.1-8B-Instruct") @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) @@ -161,7 +161,7 @@ def test_non_streaming_echo_max_tokens_zero(self, model, api_client): assert response.usage.completion_tokens == 0 -@pytest.mark.engine("sglang", "vllm") +@pytest.mark.engine("sglang", "vllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.model("meta-llama/Llama-3.1-8B-Instruct") @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) diff --git a/e2e_test/fixtures/hooks.py b/e2e_test/fixtures/hooks.py index ca15e0027..372fa4bad 100644 --- a/e2e_test/fixtures/hooks.py +++ b/e2e_test/fixtures/hooks.py @@ -82,12 +82,18 @@ def pytest_configure(config: pytest.Config) -> None: def pytest_runtest_setup(item: pytest.Item) -> None: - """Skip tests marked with ``@pytest.mark.skip_for_runtime``.""" - marker = item.get_closest_marker("skip_for_runtime") - if marker: - current_runtime = get_runtime() - skip_runtimes = marker.args - if current_runtime in skip_runtimes: + """Skip tests marked with ``@pytest.mark.skip_for_runtime``. + + A single test item can carry multiple ``skip_for_runtime`` marks — e.g. a + method-level ``@pytest.mark.skip_for_runtime("trtllm", ...)`` plus a + parametrize-attached ``pytest.param(5, marks=skip_for_runtime("tokenspeed", + ...))``. ``get_closest_marker`` only returns one of them, which silently + drops the others. Iterate every mark so a runtime that's named in any of + them gets skipped, regardless of which is "closest". + """ + current_runtime = get_runtime() + for marker in item.iter_markers(name="skip_for_runtime"): + if current_runtime in marker.args: reason = marker.kwargs.get("reason", f"Not supported on {current_runtime}") pytest.skip(f"Skipping for {current_runtime}: {reason}") diff --git a/e2e_test/infra/constants.py b/e2e_test/infra/constants.py index 19d6421f8..9bb7a194c 100644 --- a/e2e_test/infra/constants.py +++ b/e2e_test/infra/constants.py @@ -25,6 +25,7 @@ class Runtime(StrEnum): SGLANG = "sglang" VLLM = "vllm" TRTLLM = "trtllm" + TOKENSPEED = "tokenspeed" OPENAI = "openai" XAI = "xai" GEMINI = "gemini" @@ -33,7 +34,7 @@ class Runtime(StrEnum): # Convenience sets LOCAL_MODES = frozenset({ConnectionMode.HTTP, ConnectionMode.GRPC}) -LOCAL_RUNTIMES = frozenset({Runtime.SGLANG, Runtime.VLLM, Runtime.TRTLLM}) +LOCAL_RUNTIMES = frozenset({Runtime.SGLANG, Runtime.VLLM, Runtime.TRTLLM, Runtime.TOKENSPEED}) CLOUD_RUNTIMES = frozenset({Runtime.OPENAI, Runtime.XAI, Runtime.GEMINI, Runtime.ANTHROPIC}) # Fixture parameter names (used in @pytest.mark.parametrize) @@ -51,7 +52,9 @@ class Runtime(StrEnum): ENV_MODELS = "E2E_MODELS" ENV_BACKENDS = "E2E_BACKENDS" ENV_MODEL = "E2E_MODEL" -ENV_RUNTIME = "E2E_RUNTIME" # Runtime for gRPC tests: "sglang", "vllm", or "trtllm" +ENV_RUNTIME = ( + "E2E_RUNTIME" # Runtime for gRPC tests — one of Runtime.{SGLANG,VLLM,TRTLLM,TOKENSPEED} +) ENV_STARTUP_TIMEOUT = "E2E_STARTUP_TIMEOUT" ENV_SKIP_MODEL_POOL = "SKIP_MODEL_POOL" ENV_SKIP_BACKEND_SETUP = "SKIP_BACKEND_SETUP" @@ -100,11 +103,21 @@ def is_trtllm() -> bool: return get_runtime() == "trtllm" +def is_tokenspeed() -> bool: + """Check if tests are running with TokenSpeed runtime. + + Returns: + True if E2E_RUNTIME is "tokenspeed", False otherwise. 
+ """ + return get_runtime() == "tokenspeed" + + # Runtime display labels RUNTIME_LABELS = { "sglang": "SGLang", "vllm": "vLLM", "trtllm": "TensorRT-LLM", + "tokenspeed": "TokenSpeed", } ENV_SHOW_ROUTER_LOGS = "SHOW_ROUTER_LOGS" diff --git a/e2e_test/infra/model_specs.py b/e2e_test/infra/model_specs.py index 3fd1eb2cd..90d33fd56 100644 --- a/e2e_test/infra/model_specs.py +++ b/e2e_test/infra/model_specs.py @@ -61,7 +61,26 @@ def _resolve_model_path(hf_path: str) -> str: "tp": 1, "features": ["chat", "streaming", "reasoning"], }, - # Thinking/reasoning model (larger) + # Qwen3 instruct (non-thinking variant) — emits the same + # `\n{"name": ..., "arguments": ...}\n` format as + # Qwen 2.5, so the gateway's ``qwen`` tool-call parser applies. Used by + # ``TestToolChoiceQwen`` and ``TestMultiTurnToolCall``: a Qwen3 model is + # required because the Qwen2 family is not in TokenSpeed's model registry. + "Qwen/Qwen3-4B-Instruct-2507": { + "model": _resolve_model_path("Qwen/Qwen3-4B-Instruct-2507"), + "tp": 1, + "features": ["chat", "streaming", "function_calling", "tool_choice"], + }, + # Hybrid Qwen3 with the ``enable_thinking`` chat-template toggle. Used + # by ``TestEnableThinking``. Dense (``Qwen3ForCausalLM``), so it lands + # on tokenspeed's current model registry where the larger + # ``Qwen3-30B-A3B`` (``Qwen3MoeForCausalLM``) does not. Uses the + # existing ``qwen3`` reasoning parser. + "Qwen/Qwen3-4B": { + "model": _resolve_model_path("Qwen/Qwen3-4B"), + "tp": 1, + "features": ["chat", "streaming", "thinking", "reasoning"], + }, "Qwen/Qwen3-30B-A3B": { "model": _resolve_model_path("Qwen/Qwen3-30B-A3B"), "tp": 1, @@ -236,7 +255,7 @@ def get_model_spec(model_id: str) -> dict: DEFAULT_MODEL_PATH = MODEL_SPECS["meta-llama/Llama-3.1-8B-Instruct"]["model"] DEFAULT_SMALL_MODEL_PATH = MODEL_SPECS["meta-llama/Llama-3.2-1B-Instruct"]["model"] DEFAULT_REASONING_MODEL_PATH = MODEL_SPECS["deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"]["model"] -DEFAULT_ENABLE_THINKING_MODEL_PATH = MODEL_SPECS["Qwen/Qwen3-30B-A3B"]["model"] +DEFAULT_ENABLE_THINKING_MODEL_PATH = MODEL_SPECS["Qwen/Qwen3-4B"]["model"] DEFAULT_QWEN_FUNCTION_CALLING_MODEL_PATH = MODEL_SPECS["Qwen/Qwen2.5-7B-Instruct"]["model"] DEFAULT_MISTRAL_FUNCTION_CALLING_MODEL_PATH = MODEL_SPECS["mistralai/Mistral-7B-Instruct-v0.3"][ "model" diff --git a/e2e_test/infra/worker.py b/e2e_test/infra/worker.py index 8e72d4e04..750581714 100644 --- a/e2e_test/infra/worker.py +++ b/e2e_test/infra/worker.py @@ -34,7 +34,7 @@ class Worker: """A single inference worker process.""" model_id: str - engine: str # "sglang", "vllm", or "trtllm" + engine: str # "sglang", "vllm", "trtllm", or "tokenspeed" port: int gpu_ids: list[int] mode: ConnectionMode = ConnectionMode.HTTP @@ -178,6 +178,13 @@ def _build_cmd(self) -> list[str]: return self._build_vllm_http_cmd(model_path, tp_size, spec) elif self.engine == "trtllm": return self._build_trtllm_cmd(model_path, tp_size, spec) + elif self.engine == "tokenspeed": + if self.mode != ConnectionMode.GRPC: + raise ValueError( + "TokenSpeed e2e workers only support gRPC mode; " + "HTTP mode would go through the existing OpenAI frontend." + ) + return self._build_tokenspeed_grpc_cmd(model_path, tp_size, spec) else: raise ValueError(f"Unsupported engine: {self.engine}") @@ -261,6 +268,54 @@ def _build_vllm_base_cmd( cmd.extend(extra) return cmd + def _build_tokenspeed_grpc_cmd(self, model_path: str, tp_size: int, spec: dict) -> list[str]: + """Build TokenSpeed gRPC server command. 
+ + Launches the SMG-hosted TokenSpeed gRPC server + (``smg_grpc_servicer.tokenspeed``) which wraps TokenSpeed's AsyncLLM + behind the dedicated ``tokenspeed.grpc.scheduler`` service. + Auto-detected as TokenSpeed by the Rust router via its native + service-name handshake. + """ + cmd = [ + "python3", + "-m", + "smg_grpc_servicer.tokenspeed", + # Upstream renamed ``--model-path`` to ``--model`` (with the old + # name kept only as a positional alias). Use the new flag form. + "--model", + model_path, + "--host", + DEFAULT_HOST, + "--port", + str(self.port), + "--tensor-parallel-size", + str(tp_size), + "--log-level", + "warning", + # Mirrors what trtllm does and what sglang/vllm do implicitly: + # the smg gateway translates ``tool_choice=required`` and + # ``tool_choice={function}`` into a json_schema constraint on the + # sampling-params proto. TokenSpeed honors that constraint only + # when a grammar backend is configured — its default is ``None``, + # which silently drops the constraint and lets the model free-run. + "--grammar-backend", + "xgrammar", + # Per-token sampled-token logprobs are gated by this flag in + # tokenspeed (``ServerArgs.enable_output_logprobs`` defaults + # OFF). Without it, requests asking for logprobs silently + # receive empty arrays — see test_chat_completion[*-5-*] which + # exercises ``logprobs=True, top_logprobs=5`` and asserts + # logprobs are returned. Top-K logprobs are still missing + # upstream (``--enable-top-logprobs`` is not yet implemented), + # so those parametrize variants stay skipped. + "--enable-output-logprobs", + ] + extra = spec.get("tokenspeed_args", []) + if extra: + cmd.extend(extra) + return cmd + def _build_trtllm_cmd(self, model_path: str, tp_size: int, spec: dict) -> list[str]: """Build TensorRT-LLM gRPC server command.""" # Create config file to enable xgrammar guided decoding diff --git a/e2e_test/responses/test_sampling_params.py b/e2e_test/responses/test_sampling_params.py index 2faf34994..d7fbb7180 100644 --- a/e2e_test/responses/test_sampling_params.py +++ b/e2e_test/responses/test_sampling_params.py @@ -103,7 +103,7 @@ class TestSamplingParamsLocal(_SamplingParamsBase): """Regular model (Qwen via SGLang).""" -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.e2e @pytest.mark.model("openai/gpt-oss-20b") diff --git a/e2e_test/responses/test_state_management.py b/e2e_test/responses/test_state_management.py index ce8679f5d..5ec76e1db 100644 --- a/e2e_test/responses/test_state_management.py +++ b/e2e_test/responses/test_state_management.py @@ -328,7 +328,7 @@ def test_mutually_exclusive_parameters(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.e2e @pytest.mark.model("openai/gpt-oss-20b") diff --git a/e2e_test/responses/test_streaming_events.py b/e2e_test/responses/test_streaming_events.py index be06235de..4c2613a3a 100644 --- a/e2e_test/responses/test_streaming_events.py +++ b/e2e_test/responses/test_streaming_events.py @@ -106,7 +106,7 @@ def test_output_item_event_emitted(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.e2e 
@pytest.mark.model("openai/gpt-oss-20b") diff --git a/e2e_test/responses/test_structured_output.py b/e2e_test/responses/test_structured_output.py index 359910942..9567f4241 100644 --- a/e2e_test/responses/test_structured_output.py +++ b/e2e_test/responses/test_structured_output.py @@ -115,7 +115,7 @@ def test_structured_output_json_schema(self, model, api_client): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.e2e @pytest.mark.model("openai/gpt-oss-20b") diff --git a/e2e_test/responses/test_tools_call.py b/e2e_test/responses/test_tools_call.py index f4e5bda92..8e906260b 100644 --- a/e2e_test/responses/test_tools_call.py +++ b/e2e_test/responses/test_tools_call.py @@ -763,7 +763,7 @@ def _check_stream(events, expected_label): # ============================================================================= -@pytest.mark.engine("sglang", "vllm", "trtllm") +@pytest.mark.engine("sglang", "vllm", "trtllm", "tokenspeed") @pytest.mark.gpu(2) @pytest.mark.e2e @pytest.mark.model("openai/gpt-oss-20b") diff --git a/e2e_test/router/test_mmlu.py b/e2e_test/router/test_mmlu.py index 2b7937116..1f85e09ff 100644 --- a/e2e_test/router/test_mmlu.py +++ b/e2e_test/router/test_mmlu.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) -@pytest.mark.engine("sglang", "vllm") +@pytest.mark.engine("sglang", "vllm", "tokenspeed") @pytest.mark.gpu(1) @pytest.mark.e2e @pytest.mark.parametrize("setup_backend", ["grpc"], indirect=True) diff --git a/e2e_test/router/test_worker_api.py b/e2e_test/router/test_worker_api.py index 742241ec3..46017e5e2 100644 --- a/e2e_test/router/test_worker_api.py +++ b/e2e_test/router/test_worker_api.py @@ -219,6 +219,11 @@ def test_igw_multiple_workers(self): @pytest.mark.e2e +# TokenSpeed deliberately excluded: this test class spins up its worker +# via ``ConnectionMode.HTTP``, and ``Worker._build_tokenspeed_grpc_cmd`` +# rejects HTTP mode — TokenSpeed has no HTTP frontend in this repo. +# Including ``tokenspeed`` here would fail deterministically on every +# run rather than validate health-check behaviour. @pytest.mark.engine("sglang", "vllm") @pytest.mark.gpu(1) class TestDisableHealthCheck: diff --git a/grpc_servicer/pyproject.toml b/grpc_servicer/pyproject.toml index bea1e10f9..893255b98 100644 --- a/grpc_servicer/pyproject.toml +++ b/grpc_servicer/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "smg-grpc-servicer" version = "0.5.2" -description = "SMG gRPC servicer implementations for LLM inference engines (vLLM, SGLang, MLX)" +description = "SMG gRPC servicer implementations for LLM inference engines (vLLM, SGLang, MLX, TokenSpeed)" requires-python = ">=3.10" dependencies = [ "smg-grpc-proto>=0.4.6", @@ -36,6 +36,23 @@ sglang = ["sglang>=0.5.10"] # without this floor, installing [mlx] against an older proto build would # crash at import time when smg_grpc_servicer.mlx.server runs. mlx = ["smg-grpc-proto>=0.4.7", "mlx>=0.22.0", "mlx-lm>=0.22.0"] +# Note: there is intentionally no ``tokenspeed`` extra. TokenSpeed is not +# published to PyPI; it is installed out-of-tree from the lightseekorg +# checkout via ``scripts/ci_install_tokenspeed.sh`` (CI) or a manual +# ``pip install -e ./tokenspeed/python`` (local dev). An extra named +# ``tokenspeed`` would imply ``pip install smg-grpc-servicer[tokenspeed]`` +# yields a working tokenspeed setup; it does not. 
+test = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", +] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +markers = [ + "tokenspeed: tests that require TokenSpeed", +] [project.urls] Homepage = "https://github.com/lightseekorg/smg" diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/__init__.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/__init__.py new file mode 100644 index 000000000..d5ced6c52 --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/__init__.py @@ -0,0 +1,11 @@ +"""TokenSpeed gRPC servicer implementation. + +Mirrors smg_grpc_servicer.vllm / smg_grpc_servicer.sglang. Wraps TokenSpeed's +AsyncLLM (main-process async frontend) behind the dedicated +tokenspeed.grpc.scheduler service so the Rust router can auto-detect the +backend from the advertised service name and route traffic to TokenSpeed +through its native TokenSpeedSchedulerClient. +""" + +from smg_grpc_servicer.tokenspeed.servicer import TokenSpeedSchedulerServicer + +__all__ = ["TokenSpeedSchedulerServicer"] diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py new file mode 100644 index 000000000..b4e6fb0e6 --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/__main__.py @@ -0,0 +1,71 @@ +"""CLI entrypoint for the TokenSpeed gRPC server. + +Usage:: + + python -m smg_grpc_servicer.tokenspeed --model <model-path> --host 127.0.0.1 --port 50051 + +All :class:`tokenspeed.runtime.utils.server_args.ServerArgs` flags are accepted +verbatim (we reuse TokenSpeed's own ``prepare_server_args`` so there is no +flag drift between the HTTP and gRPC frontends). +""" + +from __future__ import annotations + +import asyncio +import logging +import sys + +import uvloop +from tokenspeed.runtime.utils.server_args import prepare_server_args + +from smg_grpc_servicer.tokenspeed.server import serve_grpc + + +def main(argv: list[str] | None = None) -> None: + if argv is None: + argv = sys.argv[1:] + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(name)s] %(levelname)s %(message)s", + ) + + # TokenSpeed's ``ServerArgs.resolve_kernel_backends`` defaults + # ``sampling_backend`` to ``"greedy"`` when the user doesn't pass + # ``--sampling-backend``. The greedy backend is argmax-only and + # ignores per-request ``temperature``/``top_p``/``top_k`` — fine for + # the legacy CLI where users opt in to sampling explicitly, but + # disastrous for a gateway-fronted gRPC servicer where per-request + # sampling params arrive on every call. With Llama-3.2-1B the + # always-argmax behavior collapses into single-token loops + # (\\n×N, ' ('×N, "no"×N) within a few hundred steps and + # generation runs to ``max_new_tokens`` — the smg e2e function-calling + # suite makes this directly observable. Force a sampling-respecting + # default unless the operator explicitly chose one. + if not any(a == "--sampling-backend" or a.startswith("--sampling-backend=") for a in argv): + argv = [*argv, "--sampling-backend", "flashinfer"] + + # TokenSpeed's logprob computation is gated by ``--enable-output-logprobs`` + # (default OFF, see ``ServerArgs.enable_output_logprobs``); without the + # flag, requests asking for logprobs receive empty arrays rather than an + # error. The smg gateway's OpenAI-compat path expects per-token logprobs + # whenever ``logprobs=True`` is set, so flip the flag on by default for a + # gateway-fronted gRPC servicer. Operators who want the smaller CUDA-graph + # footprint can pass ``--enable-output-logprobs=False`` explicitly.
+ # ``--enable-top-logprobs`` is intentionally NOT injected: TokenSpeed + # raises at startup when it's set (the path is not yet implemented). + if not any( + a == "--enable-output-logprobs" or a.startswith("--enable-output-logprobs=") for a in argv + ): + argv = [*argv, "--enable-output-logprobs"] + + server_args = prepare_server_args(argv) + # The scheduler processes will read these env vars; make sure we ran + # through TokenSpeed's shared env/resource setup path instead of + # duplicating it here. + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + asyncio.run(serve_grpc(server_args)) + + +if __name__ == "__main__": + main() diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/health_servicer.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/health_servicer.py new file mode 100644 index 000000000..d6b04a62a --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/health_servicer.py @@ -0,0 +1,130 @@ +"""Standard ``grpc.health.v1.Health`` servicer for the TokenSpeed backend. + +Mirrors ``smg_grpc_servicer.sglang.health_servicer.SGLangHealthServicer`` — +same service-name semantics, same lifecycle (NOT_SERVING → SERVING → NOT_SERVING), +same ``check/watch`` contract — but sources liveness signals from a TokenSpeed +:class:`AsyncLLM` instead of an SGLang ``GrpcRequestManager``. + +The Rust router uses this health check to auto-detect the backend runtime. +TokenSpeed ships its own ``tokenspeed.grpc.scheduler.TokenSpeedScheduler`` +service identity (see ``proto/tokenspeed_scheduler.proto``) so the probe +distinguishes TokenSpeed workers from real SGLang workers regardless of any +wire-level message-type sharing between the two backends. +""" + +from __future__ import annotations + +import logging +import time +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING + +import grpc +from grpc_health.v1 import health_pb2, health_pb2_grpc +from smg_grpc_proto.generated import tokenspeed_scheduler_pb2 + +if TYPE_CHECKING: + from tokenspeed.runtime.engine.async_llm import AsyncLLM + +logger = logging.getLogger(__name__) + +# Seconds of scheduler silence — with pending requests — before we report +# NOT_SERVING. Matches the SGLang equivalent so oncall dashboards are aligned. +STUCK_SCHEDULER_THRESHOLD_SEC = 30.0 + +# Source the advertised service name from the proto descriptor so a future +# ``package`` or ``service`` rename in tokenspeed_scheduler.proto stays in +# sync without a hand-edited string here. +TOKENSPEED_SCHEDULER_SERVICE_NAME = tokenspeed_scheduler_pb2.DESCRIPTOR.services_by_name[ + "TokenSpeedScheduler" +].full_name + + +class TokenSpeedHealthServicer(health_pb2_grpc.HealthServicer): + """Health servicer that tracks TokenSpeed's AsyncLLM liveness. + + Advertises two service levels: + + * ``""`` (empty) — overall server health, flipped to SERVING once the + warmup request succeeds and back to NOT_SERVING on shutdown. + * ``tokenspeed.grpc.scheduler.TokenSpeedScheduler`` — readiness: the + base status, plus a scheduler-responsiveness check (if there are + pending requests but the scheduler hasn't pushed output for >30s, + report NOT_SERVING). 
+ """ + + OVERALL_SERVER = "" + TOKENSPEED_SERVICE = TOKENSPEED_SCHEDULER_SERVICE_NAME + + def __init__(self, async_llm: AsyncLLM, scheduler_info: dict): + self.async_llm = async_llm + self.scheduler_info = scheduler_info + self._serving_status: dict[str, int] = { + self.OVERALL_SERVER: health_pb2.HealthCheckResponse.NOT_SERVING, + self.TOKENSPEED_SERVICE: health_pb2.HealthCheckResponse.NOT_SERVING, + } + logger.info("TokenSpeed gRPC health service initialized") + + def set_serving(self) -> None: + """Flip both services to SERVING (call after successful warmup).""" + self._serving_status[self.OVERALL_SERVER] = health_pb2.HealthCheckResponse.SERVING + self._serving_status[self.TOKENSPEED_SERVICE] = health_pb2.HealthCheckResponse.SERVING + logger.info("TokenSpeed gRPC health status -> SERVING") + + def set_not_serving(self) -> None: + """Flip both services to NOT_SERVING (call on shutdown).""" + self._serving_status[self.OVERALL_SERVER] = health_pb2.HealthCheckResponse.NOT_SERVING + self._serving_status[self.TOKENSPEED_SERVICE] = health_pb2.HealthCheckResponse.NOT_SERVING + logger.info("TokenSpeed gRPC health status -> NOT_SERVING") + + async def Check( + self, + request: health_pb2.HealthCheckRequest, + context: grpc.aio.ServicerContext, + ) -> health_pb2.HealthCheckResponse: + service_name = request.service + logger.debug("Health check request for service=%r", service_name) + + if self.async_llm.gracefully_exit: + return health_pb2.HealthCheckResponse(status=health_pb2.HealthCheckResponse.NOT_SERVING) + + if service_name == self.OVERALL_SERVER: + return health_pb2.HealthCheckResponse( + status=self._serving_status.get( + self.OVERALL_SERVER, health_pb2.HealthCheckResponse.NOT_SERVING + ) + ) + + if service_name == self.TOKENSPEED_SERVICE: + base = self._serving_status.get( + self.TOKENSPEED_SERVICE, health_pb2.HealthCheckResponse.NOT_SERVING + ) + if base != health_pb2.HealthCheckResponse.SERVING: + return health_pb2.HealthCheckResponse(status=base) + + # Scheduler-stuck check: pending work but no recent output. + time_since_last_receive = time.time() - self.async_llm.last_receive_tstamp + pending = len(self.async_llm.rid_to_state) + if time_since_last_receive > STUCK_SCHEDULER_THRESHOLD_SEC and pending > 0: + logger.warning( + "Scheduler appears stuck: %.1fs since last receive, %d pending requests", + time_since_last_receive, + pending, + ) + return health_pb2.HealthCheckResponse( + status=health_pb2.HealthCheckResponse.NOT_SERVING + ) + + return health_pb2.HealthCheckResponse(status=health_pb2.HealthCheckResponse.SERVING) + + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(f"Unknown service: {service_name}") + return health_pb2.HealthCheckResponse(status=health_pb2.HealthCheckResponse.SERVICE_UNKNOWN) + + async def Watch( + self, + request: health_pb2.HealthCheckRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncIterator[health_pb2.HealthCheckResponse]: + # K8s probes use Check, not Watch — we emit the current status once. + yield await self.Check(request, context) diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/scheduler_launcher.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/scheduler_launcher.py new file mode 100644 index 000000000..64acb18fa --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/scheduler_launcher.py @@ -0,0 +1,60 @@ +"""Scheduler subprocess launcher for the TokenSpeed gRPC server. 
+ +Mirrors ``smg_grpc_servicer.sglang.scheduler_launcher`` but delegates to +TokenSpeed's ``_launch_subprocesses``: we get back a fully-initialised +``AsyncLLM`` along with the scheduler info dict. All scheduler/DP-controller +spawning, multiprocessing start-method, and env priming already live inside +``_launch_subprocesses`` — we only wrap it to return what the gRPC server +cares about and to keep the call site symmetric with the sibling backends. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from tokenspeed.runtime.engine.async_llm import AsyncLLM +from tokenspeed.runtime.entrypoints.engine import _launch_subprocesses +from tokenspeed.runtime.utils.server_args import PortArgs, ServerArgs + +logger = logging.getLogger(__name__) + + +def launch_engine( + server_args: ServerArgs, + port_args: PortArgs | None = None, +) -> tuple[AsyncLLM, dict[str, Any]]: + """Launch TokenSpeed scheduler subprocess(es) and the main-process AsyncLLM. + + Returns: + A tuple ``(async_llm, scheduler_info)``. ``async_llm`` is the live + :class:`AsyncLLM` that the gRPC servicer will drive. ``scheduler_info`` + is the dict rank-0 sent back once its scheduler was ready (contains + e.g. ``max_total_num_tokens``, ``max_req_input_len``, ...). + + Raises: + RuntimeError: If rank-0 scheduler fails to initialize. The original + ``_launch_subprocesses`` surfaces this by re-raising the EOF/assertion + error — we propagate it unchanged. + """ + async_llm, _template_manager, scheduler_info = _launch_subprocesses( + server_args=server_args, + port_args=port_args, + ) + + # Non-zero rank nodes return (None, None, None) from _launch_subprocesses + # and block forever on the dummy health server — they never reach the gRPC + # server. Guard against callers relying on this return on secondary nodes. + if async_llm is None: + raise RuntimeError( + "launch_engine() returned no AsyncLLM. This means the current node " + "is not rank 0 in a multi-node deployment, or the scheduler died " + "during initialization. Only rank 0 may serve gRPC traffic." 
+ ) + + logger.info( + "TokenSpeed engine ready: max_total_num_tokens=%s max_req_input_len=%s", + scheduler_info.get("max_total_num_tokens"), + scheduler_info.get("max_req_input_len"), + ) + return async_llm, scheduler_info diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/server.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/server.py new file mode 100644 index 000000000..bbe67e69a --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/server.py @@ -0,0 +1,195 @@ +"""Standalone TokenSpeed gRPC server — mirrors ``smg_grpc_servicer.sglang.server``.""" + +from __future__ import annotations + +import asyncio +import logging +import os +import signal +import threading +import time +from concurrent import futures + +import grpc +from grpc_health.v1 import health_pb2_grpc +from grpc_reflection.v1alpha import reflection +from smg_grpc_proto import tokenspeed_scheduler_pb2_grpc +from smg_grpc_proto.generated import tokenspeed_scheduler_pb2 +from tokenspeed.runtime.utils.server_args import ServerArgs + +from smg_grpc_servicer.tokenspeed.health_servicer import TokenSpeedHealthServicer +from smg_grpc_servicer.tokenspeed.scheduler_launcher import launch_engine +from smg_grpc_servicer.tokenspeed.servicer import TokenSpeedSchedulerServicer + +logger = logging.getLogger(__name__) + + +async def serve_grpc(server_args: ServerArgs) -> None: + """Run the TokenSpeed gRPC server until a shutdown signal is received.""" + + logger.info("Launching TokenSpeed scheduler + AsyncLLM...") + async_llm, scheduler_info = launch_engine(server_args) + + server = grpc.aio.server( + futures.ThreadPoolExecutor(max_workers=10), + options=[ + ("grpc.max_send_message_length", 1024 * 1024 * 256), + ("grpc.max_receive_message_length", 1024 * 1024 * 256), + # Match SGLang's more-permissive keepalive defaults so long + # prefill stalls don't trip GOAWAY in the Rust client. + ("grpc.http2.min_recv_ping_interval_without_data_ms", 10000), + ("grpc.keepalive_permit_without_calls", True), + ], + ) + + health_servicer = TokenSpeedHealthServicer( + async_llm=async_llm, + scheduler_info=scheduler_info, + ) + health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server) + + servicer = TokenSpeedSchedulerServicer( + async_llm=async_llm, + server_args=server_args, + scheduler_info=scheduler_info, + health_servicer=health_servicer, + ) + tokenspeed_scheduler_pb2_grpc.add_TokenSpeedSchedulerServicer_to_server(servicer, server) + + service_names = ( + tokenspeed_scheduler_pb2.DESCRIPTOR.services_by_name["TokenSpeedScheduler"].full_name, + "grpc.health.v1.Health", + reflection.SERVICE_NAME, + ) + reflection.enable_server_reflection(service_names, server) + + listen_addr = f"{server_args.host}:{server_args.port}" + server.add_insecure_port(listen_addr) + logger.info("TokenSpeed gRPC server listening on %s", listen_addr) + + await server.start() + + # Warmup on a background thread so the async server can handle the probe. + warmup_thread = threading.Thread( + target=_wait_and_warmup, + args=(server_args, health_servicer), + daemon=True, + ) + warmup_thread.start() + + loop = asyncio.get_running_loop() + stop_event = asyncio.Event() + + def _signal_handler() -> None: + logger.info("Received shutdown signal") + stop_event.set() + + for sig in (signal.SIGTERM, signal.SIGINT): + try: + loop.add_signal_handler(sig, _signal_handler) + except NotImplementedError: + # Windows and some exotic envs don't support loop.add_signal_handler. 
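+            # Fall back to the interpreter's default handling there (for
+            # SIGINT that typically means KeyboardInterrupt).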
+ pass + + try: + await stop_event.wait() + finally: + logger.info("Shutting down TokenSpeed gRPC server") + try: + await servicer.shutdown() + except Exception: # noqa: BLE001 + logger.exception("servicer.shutdown() raised") + await server.stop(5.0) + if warmup_thread.is_alive(): + warmup_thread.join(timeout=5.0) + + +def _wait_and_warmup( + server_args: ServerArgs, + health_servicer: TokenSpeedHealthServicer, +) -> None: + """Probe the gRPC server until it can generate one token, then set SERVING. + + We hit the external port (not the in-process servicer) so the warmup + exercises the same code path a production caller would — including the + gRPC transport, proto codec, and scheduler IPC. + """ + if os.getenv("TOKENSPEED_SKIP_GRPC_WARMUP", "0").lower() in ("1", "true", "yes"): + logger.info("TOKENSPEED_SKIP_GRPC_WARMUP=1 — skipping warmup") + health_servicer.set_serving() + return + + grpc_url = f"{server_args.host}:{server_args.port}" + channel = grpc.insecure_channel( + grpc_url, + options=[ + ("grpc.max_send_message_length", 1024 * 1024 * 256), + ("grpc.max_receive_message_length", 1024 * 1024 * 256), + ], + ) + stub = tokenspeed_scheduler_pb2_grpc.TokenSpeedSchedulerStub(channel) + + # Wait until GetModelInfo round-trips — that's the quickest confirmation + # that the gRPC server is both bound and has a live AsyncLLM behind it. + deadline = time.time() + 180 + connected = False + while time.time() < deadline: + try: + stub.GetModelInfo( + tokenspeed_scheduler_pb2.GetModelInfoRequest(), + timeout=5, + ) + connected = True + break + except Exception as e: # noqa: BLE001 + logger.debug("Warmup: GetModelInfo not ready yet: %s", e) + time.sleep(1) + + if not connected: + logger.error("TokenSpeed gRPC warmup failed: GetModelInfo never succeeded") + channel.close() + return + + # TokenSpeed serves generative LLMs only (the proto has no Embed RPC), so + # the warmup is always a 1-token generate. + warmup_ok = False + try: + warmup = tokenspeed_scheduler_pb2.GenerateRequest( + request_id=f"WARMUP_{time.time()}", + tokenized=tokenspeed_scheduler_pb2.TokenizedInput( + input_ids=[0], + original_text="warmup", + ), + sampling_params=tokenspeed_scheduler_pb2.SamplingParams( + temperature=0.0, + max_new_tokens=1, + ), + stream=False, + ) + final = None + for resp in stub.Generate(warmup, timeout=600): + final = resp + if final is None or not final.HasField("complete"): + logger.warning( + "Warmup Generate returned no Complete frame (last=%r)", + final, + ) + else: + logger.info("Warmup generation succeeded") + warmup_ok = True + except Exception as e: # noqa: BLE001 + logger.warning("TokenSpeed warmup failed: %s", e) + finally: + channel.close() + + # NOT_SERVING keeps the pod out of K8s readiness rotation when warmup + # never produced a Complete frame. + if warmup_ok: + health_servicer.set_serving() + logger.info("TokenSpeed gRPC server is ready to serve") + else: + logger.error( + "TokenSpeed gRPC warmup did not produce a complete frame; " + "health stays NOT_SERVING. K8s readiness will keep this " + "worker out of rotation until manually restarted." + ) diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py new file mode 100644 index 000000000..8d9387f1b --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py @@ -0,0 +1,937 @@ +"""TokenSpeed gRPC servicer. 
+ +Implements the ``tokenspeed.grpc.scheduler.TokenSpeedScheduler`` gRPC service +on top of TokenSpeed's :class:`tokenspeed.runtime.engine.async_llm.AsyncLLM` — +the main-process async frontend that replaced ``TokenizerManager`` in the +AsyncLLM refactor. + +Wire identity & message catalog +------------------------------- +TokenSpeed ships a fully independent proto (``proto/tokenspeed_scheduler.proto``) +with a distinct package, service, and message catalog. The Rust gateway's +``DetectBackendStep`` identifies the worker natively from the service name — +no SGLang-look-alike hack, no runtime marker probe. The proto's field set is +intentionally minimal (top-tier LLM serving only): no Embed, no +GetTokenizer, no SubscribeKvEvents, no multimodal, no PD-disaggregated +serving, no LoRA, no hidden-state forwarding, no classifier outputs. +Anything in that list has to be added to the proto first; it doesn't ride +on a shared SGLang message anymore. +""" + +from __future__ import annotations + +import asyncio +import dataclasses +import json +import logging +import os +import re +import time +from collections.abc import AsyncIterator +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +import grpc +from google.protobuf.struct_pb2 import Struct +from google.protobuf.timestamp_pb2 import Timestamp +from smg_grpc_proto import tokenspeed_scheduler_pb2_grpc +from smg_grpc_proto.generated import tokenspeed_scheduler_pb2 + +from smg_grpc_servicer.tokenspeed.health_servicer import TokenSpeedHealthServicer + +if TYPE_CHECKING: + # Type-only imports — not resolved at module load so the servicer is + # importable in test environments that stub AsyncLLM / ServerArgs. + from tokenspeed.runtime.engine.async_llm import AsyncLLM + from tokenspeed.runtime.utils.server_args import ServerArgs + +logger = logging.getLogger(__name__) + +HEALTH_CHECK_TIMEOUT = int(os.getenv("TOKENSPEED_HEALTH_CHECK_TIMEOUT", "20")) + + +def _lazy_generate_req_input(): + """Late import for ``tokenspeed.runtime.engine.io_struct.GenerateReqInput``. + + Kept lazy so the top of this module loads in test environments that stub + the TokenSpeed engine surface (unit tests don't need a fully-working + TokenSpeed install to exercise proto ↔ request-input conversion). + """ + from tokenspeed.runtime.engine.io_struct import GenerateReqInput + + return GenerateReqInput + + +def _finish_reason_to_dict(reason: Any) -> dict | None: + """Normalise a TokenSpeed finish reason into a dict. + + TokenSpeed emits ``BaseFinishReason``-style objects (or an already- + normalised dict) in ``meta_info["finish_reason"]``; downstream code + expects a dict with at minimum ``{"type": ...}`` and optionally + ``{"matched": int|str}``. ``None`` means "still running". + + We duck-type on ``to_json()`` so the servicer module loads without + pulling in TokenSpeed's full request-processing graph. Unknown shapes + raise ``TypeError`` rather than silently flipping ``length`` / ``abort`` + to ``stop`` — the caller maps that to ``StatusCode.INTERNAL``. + """ + if reason is None or isinstance(reason, dict): + return reason + to_json = getattr(reason, "to_json", None) + if callable(to_json): + result = to_json() + if isinstance(result, dict): + return result + raise TypeError( + f"finish_reason {type(reason).__name__!r}.to_json() returned " + f"{type(result).__name__!r}; expected dict with at least 'type'." + ) + raise TypeError( + f"Unknown finish_reason shape {type(reason).__name__!r}; expected " + f"a dict or an object with a to_json() method." 
+ ) + + +class TokenSpeedSchedulerServicer(tokenspeed_scheduler_pb2_grpc.TokenSpeedSchedulerServicer): + """gRPC servicer exposing TokenSpeed's AsyncLLM over the dedicated TokenSpeed proto.""" + + def __init__( + self, + async_llm: AsyncLLM, + server_args: ServerArgs, + scheduler_info: dict, + health_servicer: TokenSpeedHealthServicer | None = None, + ): + self.async_llm = async_llm + self.server_args = server_args + self.scheduler_info = scheduler_info + self.health_servicer = health_servicer + self.start_time = time.time() + + # Drive AsyncLLM's output-dispatch loop. This is idempotent — the + # first caller creates the handle loop; subsequent callers (including + # the HealthCheck RPC) are no-ops thanks to ``no_create_loop``. + self.async_llm.auto_create_handle_loop() + + logger.info("TokenSpeedSchedulerServicer initialized") + + # ------------------------------------------------------------------ + # Generate (server-streaming) + # ------------------------------------------------------------------ + + async def Generate( + self, + request: tokenspeed_scheduler_pb2.GenerateRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncIterator[tokenspeed_scheduler_pb2.GenerateResponse]: + rid = request.request_id + logger.info("Generate request %s (stream=%s)", rid, request.stream) + + try: + req_obj = self._build_generate_req(request) + except ValueError as e: + await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + return + + # For n>1, tokenspeed's batch handler generates fresh UUIDs per + # sub-request and tags each streamed dict with a sequential + # ``index`` (see tokenizer_manager.py::_handle_batch_request). + # Non-streaming n>1 yields a *list* of final dicts instead. We + # handle both shapes below. + expanded_rid = getattr(req_obj, "rid", None) + + # When the client sets ``no_stop_trim``, the matched stop token must + # remain in the proto's ``output_ids`` so the gateway-side detokenizer + # can render it (relevant when ``skip_special_tokens=False`` is also + # set). Capture once and thread through the response builders. + no_stop_trim = bool(request.sampling_params.no_stop_trim) + + aborted = False + try: + async for output in self.async_llm.generate_request(req_obj): + # Non-streaming n>1 emits a list of final dicts in one yield. 
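+                # (Illustrative shape for n=2:
+                #  [{"index": 0, "output_ids": [...], "meta_info": {...}},
+                #   {"index": 1, "output_ids": [...], "meta_info": {...}}];
+                #  these are the keys the per-item handling below reads.)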
+ if isinstance(output, list): + for idx, item in enumerate(output): + item_reason = _finish_reason_to_dict( + item.get("meta_info", {}).get("finish_reason") + ) + if item_reason and item_reason.get("type") == "abort": + code = _abort_status_code(item_reason) + await context.abort(code, item_reason.get("message") or "aborted") + return + ci = int(item.get("index", idx)) + yield self._complete_response( + rid, item, item_reason, ci, no_stop_trim=no_stop_trim + ) + continue + + meta = output.get("meta_info", {}) + reason_dict = _finish_reason_to_dict(meta.get("finish_reason")) + is_finished = reason_dict is not None + + if reason_dict is not None and reason_dict.get("type") == "abort": + code = _abort_status_code(reason_dict) + await context.abort(code, reason_dict.get("message") or "aborted") + return + + choice_index = int(output.get("index", 0)) + + if request.stream: + yield self._chunk_response( + rid, output, reason_dict, choice_index, no_stop_trim=no_stop_trim + ) + if is_finished: + yield self._complete_response( + rid, output, reason_dict, choice_index, no_stop_trim=no_stop_trim + ) + elif is_finished: + yield self._complete_response( + rid, output, reason_dict, choice_index, no_stop_trim=no_stop_trim + ) + + except ValueError as e: + logger.warning("Generate invalid request %s: %s", rid, e) + await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) + except asyncio.CancelledError: + # Client disconnected — sweep every scheduler-side rid we minted + # (including the per-choice ``{rid}-n{i}`` children n>1 creates) + # so abandoned requests don't keep consuming GPU work. + aborted = True + if isinstance(expanded_rid, list): + for r in expanded_rid: + self.async_llm.abort_request(r) + else: + self.async_llm.abort_request(rid) + raise + except grpc.aio.AbortError: + raise + except Exception as e: + logger.exception("Generate failed for request %s", rid) + await context.abort(grpc.StatusCode.INTERNAL, str(e)) + finally: + # Defensive cleanup — the scheduler owns rid_to_state, but if the + # stream was torn down before finish we need to notify it. When + # n>1 we expanded rid to a list of per-choice ids, so walk them. + if not aborted: + rids_to_check = ( + list(expanded_rid) + if isinstance(expanded_rid, list) + else ([expanded_rid] if isinstance(expanded_rid, str) else []) + ) + for r in rids_to_check: + state = self.async_llm.rid_to_state.get(r) + if state is not None and not getattr(state, "finished", False): + self.async_llm.abort_request(r) + + # ------------------------------------------------------------------ + # HealthCheck (unary) + # ------------------------------------------------------------------ + + async def HealthCheck( + self, + request: tokenspeed_scheduler_pb2.HealthCheckRequest, + context: grpc.aio.ServicerContext, + ) -> tokenspeed_scheduler_pb2.HealthCheckResponse: + """Deep health probe — sends a 1-token generation to the scheduler. + + Mirrors SGLang's contract exactly: if the scheduler pushes *any* + output within ``HEALTH_CHECK_TIMEOUT`` seconds, we consider it alive. + We bypass the normal AsyncLLM lock/metrics by crafting a dedicated + request with ``log_metrics=False`` so health checks don't skew + Prometheus counters. + """ + rid = f"HEALTH_CHECK_{time.time()}" + + if self.async_llm.gracefully_exit: + return tokenspeed_scheduler_pb2.HealthCheckResponse( + healthy=False, message="Server is shutting down" + ) + + # TokenSpeed only serves generative LLMs at this layer (the proto + # has no Embed RPC), so the probe is always a 1-token generate. 
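+        # The probe mirrors the external warmup request in server.py: a
+        # single input token, greedy sampling, one new token.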
+ GenerateReqInput = _lazy_generate_req_input() + probe = GenerateReqInput( + input_ids=[0], + sampling_params={"max_new_tokens": 1, "temperature": 0.0}, + log_metrics=False, + ) + probe.rid = rid + + tic = time.time() + + async def _drive_probe() -> bool: + try: + async for _ in self.async_llm.generate_request(probe): + return True + except Exception as e: # noqa: BLE001 — the probe is best-effort. + logger.warning("Health probe failed: %s", e) + return False + return False + + task = asyncio.create_task(_drive_probe()) + try: + while time.time() - tic < HEALTH_CHECK_TIMEOUT: + await asyncio.sleep(0.5) + # Any scheduler push after we started counts as healthy. + if self.async_llm.last_receive_tstamp > tic: + return tokenspeed_scheduler_pb2.HealthCheckResponse( + healthy=True, + message="Health check passed", + ) + if task.done(): + return tokenspeed_scheduler_pb2.HealthCheckResponse( + healthy=bool(task.result()), + message=( + "Health check passed" + if task.result() + else "Scheduler returned no output" + ), + ) + finally: + if not task.done(): + task.cancel() + # Best-effort cleanup: the probe rid shouldn't linger. + self.async_llm.abort_request(rid) + + return tokenspeed_scheduler_pb2.HealthCheckResponse( + healthy=False, + message=f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s", + ) + + # ------------------------------------------------------------------ + # Abort (unary) + # ------------------------------------------------------------------ + + async def Abort( + self, + request: tokenspeed_scheduler_pb2.AbortRequest, + _context: grpc.aio.ServicerContext, + ) -> tokenspeed_scheduler_pb2.AbortResponse: + """Abort the request + any per-choice expansions from n>1. + + Generate rewrites ``n>1`` requests into a list of rids + ``[{request_id}-n0, {request_id}-n1, ...]`` so TokenSpeed's batch + path sees unique rids. Aborting only the original ``request_id`` + would leave those children running — we sweep them all. + """ + rid = request.request_id + logger.info("Abort request %s", rid) + state_map = self.async_llm.rid_to_state + + # Anchored regex avoids matching unrelated rids like "{rid}-name". 
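+        # Illustrative rid "req-1": "req-1" and "req-1-n3" are swept, while
+        # "req-1-name" and "req-1n3" are not.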
+ child_pattern = re.compile(rf"^{re.escape(rid)}-n\d+$") + targets = [r for r in state_map if r == rid or child_pattern.match(r)] + + try: + for r in targets: + self.async_llm.abort_request(r) + known = bool(targets) + return tokenspeed_scheduler_pb2.AbortResponse( + success=known, + message=( + f"Aborted {len(targets)} request(s) for {rid}" + if known + else f"Request {rid} not found" + ), + ) + except Exception as e: + logger.exception("Abort failed for %s", rid) + return tokenspeed_scheduler_pb2.AbortResponse(success=False, message=str(e)) + + # ------------------------------------------------------------------ + # GetModelInfo (unary) + # ------------------------------------------------------------------ + + async def GetModelInfo( + self, + _request: tokenspeed_scheduler_pb2.GetModelInfoRequest, + _context: grpc.aio.ServicerContext, + ) -> tokenspeed_scheduler_pb2.GetModelInfoResponse: + model_config = self.async_llm.model_config + hf_config = getattr(model_config, "hf_config", None) + + eos = getattr(hf_config, "eos_token_id", None) if hf_config else None + if isinstance(eos, int): + eos_token_ids = [eos] + elif isinstance(eos, list): + eos_token_ids = list(eos) + else: + eos_token_ids = [] + + max_req_input_len = self.scheduler_info.get("max_req_input_len") or ( + self.async_llm.max_req_input_len or 0 + ) + + # TokenSpeed's GetModelInfoResponse intentionally drops + # ``is_generation`` (always true), ``supports_vision`` (always false), + # and ``id2label_json`` / ``num_labels`` (not a classifier serving + # path). The Rust client fills those slots back in when translating + # to its SGLang-shaped wrapper. + # Upstream renamed ``ServerArgs.model_path`` → ``ServerArgs.model`` + # and ``ServerArgs.tokenizer_path`` → ``ServerArgs.tokenizer`` + # alongside the ``--model-path`` → ``--model`` flag rename. Old + # versions still set the ``_path`` form; new ones set the bare + # form. Pick whichever is populated so the servicer works against + # both. + model_path = getattr(self.server_args, "model", None) or getattr( + self.server_args, "model_path", "" + ) + tokenizer_path = getattr(self.server_args, "tokenizer", None) or getattr( + self.server_args, "tokenizer_path", "" + ) + return tokenspeed_scheduler_pb2.GetModelInfoResponse( + model_path=model_path, + tokenizer_path=tokenizer_path or "", + preferred_sampling_params=self.server_args.preferred_sampling_params or "", + weight_version="", + served_model_name=(self.server_args.served_model_name or model_path), + max_context_length=int(self.async_llm.context_len), + vocab_size=int(model_config.vocab_size), + model_type=(getattr(hf_config, "model_type", "") or "") if hf_config else "", + architectures=(getattr(hf_config, "architectures", []) or []) if hf_config else [], + eos_token_ids=eos_token_ids, + pad_token_id=(getattr(hf_config, "pad_token_id", 0) or 0) if hf_config else 0, + bos_token_id=(getattr(hf_config, "bos_token_id", 0) or 0) if hf_config else 0, + max_req_input_len=int(max_req_input_len), + ) + + # ------------------------------------------------------------------ + # GetServerInfo (unary) + # ------------------------------------------------------------------ + + async def GetServerInfo( + self, + _request: tokenspeed_scheduler_pb2.GetServerInfoRequest, + _context: grpc.aio.ServicerContext, + ) -> tokenspeed_scheduler_pb2.GetServerInfoResponse: + # TokenSpeed's ``ServerArgs`` is a dataclass, but tests sometimes pass + # a plain namespace. Fall back to ``__dict__`` so both shapes work. 
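+        # (``dataclasses.is_dataclass`` is also true for a dataclass *class*;
+        # the ``isinstance(..., type)`` guard makes sure ``asdict`` only ever
+        # sees an instance.)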
+ if dataclasses.is_dataclass(self.server_args) and not isinstance(self.server_args, type): + server_args_dict = dataclasses.asdict(self.server_args) + else: + server_args_dict = dict(getattr(self.server_args, "__dict__", {})) + server_args_struct = Struct() + server_args_struct.update(_make_json_serializable(server_args_dict)) + + scheduler_info_struct = Struct() + scheduler_info_struct.update(_make_json_serializable(dict(self.scheduler_info))) + + uptime = time.time() - self.start_time + start_timestamp = Timestamp() + start_timestamp.FromSeconds(int(self.start_time)) + + try: + import tokenspeed # local import: avoid module-load-time dependency + + version = getattr(tokenspeed, "__version__", "unknown") + except Exception: # noqa: BLE001 — fall back gracefully. + version = "unknown" + + return tokenspeed_scheduler_pb2.GetServerInfoResponse( + server_args=server_args_struct, + scheduler_info=scheduler_info_struct, + active_requests=len(self.async_llm.rid_to_state), + is_paused=False, + uptime_seconds=float(uptime), + tokenspeed_version=version, + start_time=start_timestamp, + max_total_num_tokens=int(self.scheduler_info.get("max_total_num_tokens", 0)), + ) + + # ------------------------------------------------------------------ + # GetLoads (unary) — bridges to TokenSpeed's scheduler-side load metrics + # ------------------------------------------------------------------ + + async def GetLoads( + self, + _request: tokenspeed_scheduler_pb2.GetLoadsRequest, + context: grpc.aio.ServicerContext, + ) -> tokenspeed_scheduler_pb2.GetLoadsResponse: + """Return per-DP-rank scheduler load by RPC-ing the scheduler subprocess. + + ``AsyncLLM`` inherits ``SchedulerControlClient.get_load`` which sends + ``GetLoadReqInput`` over the engine_core_client zmq channel and awaits + a ``List[GetLoadReqOutput]`` reply (one per DP rank). Each reply carries + the live counts the scheduler computes in ``event_loop._get_load``: + ``num_reqs`` (running + waiting), ``num_waiting_reqs``, and + ``num_pages`` (KV pages currently in use). We map those to the + ``SchedulerLoad`` proto plus a coarse aggregate so the router-side + consumer matches what it gets from SGLang. + """ + try: + load_outputs = await asyncio.wait_for( + self.async_llm.get_load(), timeout=HEALTH_CHECK_TIMEOUT + ) + except TimeoutError: + await context.abort( + grpc.StatusCode.DEADLINE_EXCEEDED, + f"tokenspeed scheduler did not respond to GetLoad within {HEALTH_CHECK_TIMEOUT}s", + ) + return + except Exception as e: # noqa: BLE001 + logger.exception("GetLoads failed") + await context.abort(grpc.StatusCode.INTERNAL, str(e)) + return + + page_size = int(getattr(self.async_llm.server_args, "page_size", 1) or 1) + # ``max_total_num_tokens`` lives on the scheduler-side ``scheduler_info`` + # dict that ``launch_engine`` plumbed through at boot — not directly on + # AsyncLLM. Fall back to ``server_args.max_total_num_tokens`` (used in + # tests' SimpleNamespace stubs). 
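+        # Worked example with illustrative numbers: page_size=16, a rank
+        # reporting num_pages=100, and max_total_num_tokens=100000 give
+        # num_used_tokens=1600 and token_usage=0.016.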
+ max_total_num_tokens = int( + (self.scheduler_info.get("max_total_num_tokens") if self.scheduler_info else None) + or getattr(self.async_llm.server_args, "max_total_num_tokens", 0) + or 0 + ) + + scheduler_loads: list[tokenspeed_scheduler_pb2.SchedulerLoad] = [] + total_running = 0 + total_waiting = 0 + token_usages: list[float] = [] + for lo in load_outputs: + num_running = max(0, int(lo.num_reqs) - int(lo.num_waiting_reqs)) + num_used_tokens = int(lo.num_pages) * page_size + token_usage = ( + num_used_tokens / max_total_num_tokens if max_total_num_tokens > 0 else 0.0 + ) + scheduler_loads.append( + tokenspeed_scheduler_pb2.SchedulerLoad( + dp_rank=int(lo.dp_rank), + num_running_reqs=num_running, + num_waiting_reqs=int(lo.num_waiting_reqs), + num_total_reqs=int(lo.num_reqs), + num_used_tokens=num_used_tokens, + max_total_num_tokens=max_total_num_tokens, + token_usage=token_usage, + ) + ) + total_running += num_running + total_waiting += int(lo.num_waiting_reqs) + token_usages.append(token_usage) + + aggregate = tokenspeed_scheduler_pb2.AggregateMetrics( + total_running_reqs=total_running, + total_waiting_reqs=total_waiting, + total_reqs=total_running + total_waiting, + avg_token_usage=(sum(token_usages) / len(token_usages)) if token_usages else 0.0, + ) + + return tokenspeed_scheduler_pb2.GetLoadsResponse( + timestamp=datetime.now(timezone.utc).isoformat(), + version="tokenspeed", + dp_rank_count=len(scheduler_loads), + loads=scheduler_loads, + aggregate=aggregate, + ) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + async def shutdown(self, drain_timeout_secs: float = 30.0) -> None: + """Graceful shutdown — drain in-flight requests, then kill scheduler children. + + AsyncLLM's ``sigterm_watchdog`` polls ``gracefully_exit`` every 5s, + drains ``rid_to_state`` and finally calls + ``kill_process_tree(getpid, include_parent=True)``. That works in + steady-state but the gRPC server's main coroutine may unwind before + the watchdog ticks again, in which case the scheduler subprocesses + outlive the parent and end up orphaned. To avoid that, we: + + 1. Flag ``gracefully_exit`` so AsyncLLM stops accepting work and + the watchdog will eventually run its own cleanup. + 2. Wait up to ``drain_timeout_secs`` for ``rid_to_state`` to empty. + 3. Forcibly kill the subprocess tree (``include_parent=False``) so + the scheduler children are reaped regardless of whether the + watchdog tick fires before this coroutine returns. Idempotent + with the watchdog's own ``kill_process_tree`` call. + """ + self.async_llm.gracefully_exit = True + if self.health_servicer: + self.health_servicer.set_not_serving() + + deadline = time.monotonic() + drain_timeout_secs + while time.monotonic() < deadline: + if not getattr(self.async_llm, "rid_to_state", None): + break + await asyncio.sleep(0.5) + else: + logger.warning( + "shutdown drain timed out after %.1fs with %d in-flight requests; " + "killing scheduler children anyway", + drain_timeout_secs, + len(getattr(self.async_llm, "rid_to_state", {}) or {}), + ) + + # Reap the scheduler subprocesses without taking down our own PID; + # server.py's stop sequence still needs us alive to finish gRPC drain. 
+ try: + from tokenspeed.runtime.utils.process import kill_process_tree + except ImportError: + logger.exception( + "Could not import tokenspeed.runtime.utils.process.kill_process_tree; " + "scheduler subprocesses may be orphaned" + ) + return + kill_process_tree(os.getpid(), include_parent=False) + + def _build_generate_req(self, request: tokenspeed_scheduler_pb2.GenerateRequest): + """Translate proto GenerateRequest → TokenSpeed GenerateReqInput. + + Keeps the router's pre-tokenized inputs intact (``input_ids`` set, + ``text`` left blank) so the TokenSpeed InputProcessor skips its own + tokenizer pass. + """ + if not request.HasField("tokenized"): + raise ValueError("GenerateRequest.tokenized is required") + + input_ids = list(request.tokenized.input_ids) + if not input_ids: + raise ValueError("GenerateRequest.tokenized.input_ids is empty") + + sampling = self._sampling_params_from_proto( + request.sampling_params, + reasoning_parser=getattr(self.server_args, "reasoning_parser", None), + ) + + GenerateReqInput = _lazy_generate_req_input() + obj = GenerateReqInput( + input_ids=input_ids, + sampling_params=sampling, + stream=bool(request.stream), + return_logprob=bool(request.return_logprob), + # ``logprob_start_len`` is ``optional int32`` on the wire — use + # presence-tracking, not the proto3 zero-default, to distinguish + # "client omitted" (→ SGLang's ``-1`` = no input logprobs) from + # an explicit ``0`` (→ start input logprobs at position 0). + logprob_start_len=( + request.logprob_start_len if request.HasField("logprob_start_len") else -1 + ), + top_logprobs_num=int(request.top_logprobs_num or 0), + token_ids_logprob=( + list(request.token_ids_logprob) if request.token_ids_logprob else None + ), + # Hidden-state forwarding, multimodal inputs, PD-disaggregated + # serving, LoRA hot-swap and ``log_metrics`` are intentionally + # absent from TokenSpeed's wire — leaving the engine defaults in + # place keeps the call shape simple. + ) + # Older tokenspeed's ``normalize_batch_and_arguments`` treats n>1 as + # batched and asserts ``rid`` is a list in that case. One gRPC + # request carries one rid; expand it to a list of deterministic + # per-choice rids when the caller asked for multiple samples so the + # assert doesn't fire (and the scheduler can still deduplicate). + n = sampling.get("n", 1) or 1 + if n > 1: + obj.rid = [f"{request.request_id}-n{i}" for i in range(n)] + else: + obj.rid = request.request_id + + # NOTE: We deliberately do NOT set ``obj.text`` even when the proto + # carries ``original_text``. TokenSpeed's HTTP serving_chat passes + # ``input_ids=[...], text=None`` to the engine; setting both fields + # has been observed to perturb the engine's input-processor path + # (some validators and normalizers branch on whether text is + # populated). Matching the HTTP shape — ids only, text=None — + # eliminates one source of HTTP-vs-gRPC divergence. + + return obj + + @staticmethod + def _sampling_params_from_proto( + params: tokenspeed_scheduler_pb2.SamplingParams, + *, + reasoning_parser: str | None = None, + ) -> dict[str, Any]: + """Build the dict that ``GenerateReqInput.sampling_params`` expects. + + TokenSpeed's :class:`SamplingParams` consumes this dict via + ``SamplingParams(**obj.sampling_params)``, so field names must match + the Python class (``max_new_tokens``, ``stop``, ``stop_token_ids``, ...). + """ + out: dict[str, Any] = {} + + # All sampling scalars in tokenspeed_scheduler.proto are declared + # ``optional`` (matching ``vllm_engine.proto``). 
We use + # ``HasField()`` to forward only the values the client explicitly + # set; absent fields fall through to the engine's own + # ``SamplingParams.__init__`` defaults. This eliminates the old + # truthy-check pitfall that silently dropped ``temperature=0`` + # (BFCL's intent for greedy decoding) AND the warmup-default-zero + # crash where invalid ``top_p=0.0`` / ``repetition_penalty=0.0`` + # would reach the engine from internal probe paths. + # + # When ``temperature=0`` does reach the engine (HasField=True for + # an explicitly-sent ``0.0``), the engine + # (``sampling_params.py:104-107``) sets ``top_k=1`` to engage + # greedy decoding. That's the path BFCL relies on. + for _field in ( + "max_new_tokens", + "temperature", + "top_p", + "top_k", + "min_p", + "frequency_penalty", + "presence_penalty", + "repetition_penalty", + ): + if params.HasField(_field): + out[_field] = getattr(params, _field) + + if params.min_new_tokens: + # ``min_new_tokens`` is non-optional; 0 is the "no minimum" sentinel. + out["min_new_tokens"] = params.min_new_tokens + + # Lists + if params.stop: + out["stop"] = list(params.stop) + if params.stop_token_ids: + out["stop_token_ids"] = list(params.stop_token_ids) + + # Bools (always forwarded) + out["skip_special_tokens"] = bool(params.skip_special_tokens) + out["spaces_between_special_tokens"] = bool(params.spaces_between_special_tokens) + out["ignore_eos"] = bool(params.ignore_eos) + # When set, tokenspeed's detokenizer keeps the matched stop token in + # the rendered text (see ``runtime/engine/detokenizer.py``); we also + # suppress the servicer-side ``output_ids`` strip in + # ``_generated_output_ids`` so the EOS reaches the gateway's + # detokenizer when ``skip_special_tokens=False``. + out["no_stop_trim"] = bool(params.no_stop_trim) + + # n (OpenAI-compat, passthrough) + if params.n: + out["n"] = params.n + if params.logit_bias: + out["logit_bias"] = dict(params.logit_bias) + + # Constraint types — exactly one may be set. + if params.HasField("regex"): + out["regex"] = params.regex + elif params.HasField("json_schema"): + # Mirror tokenspeed serving_chat.py: when the engine is + # running with a reasoning parser that has an xgrammar + # template (e.g. ``gpt-oss`` → ``harmony``), wrap the user's + # JSON schema as a structural tag so the grammar only + # activates inside the response channel. Without this, + # xgrammar fights the Harmony channel preamble + # (``<|channel|>analysis<|message|>…``) and the model stalls + # until ``max_tokens``. + wrapped: str | None = None + if reasoning_parser: + try: + from tokenspeed.runtime.grammar.reasoning_structural_tag import ( + structural_tag_for_reasoning_json_schema, + ) + + wrapped = structural_tag_for_reasoning_json_schema( + reasoning_parser, json.loads(params.json_schema) + ) + except ImportError: + wrapped = None + if wrapped is not None: + out["structural_tag"] = wrapped + else: + out["json_schema"] = params.json_schema + elif params.HasField("ebnf_grammar"): + out["ebnf"] = params.ebnf_grammar + elif params.HasField("structural_tag"): + out["structural_tag"] = params.structural_tag + + return out + + def _generated_output_ids( + self, + output: dict, + reason_dict: dict | None, + *, + no_stop_trim: bool = False, + ) -> list[int]: + """Return just the newly-generated tokens from a TokenSpeed output dict. + + TokenSpeed's AsyncLLM has two quirks that the SGLang gRPC proto contract + doesn't expect, both of which break the smg gateway's detokenization + layer and downstream tool-call parsing: + + 1. 
``output_ids`` is prefixed with the Llama-3 chat-template assistant + header: ``[<|eot_id|>, <|start_header_id|>, "assistant", + <|end_header_id|>, "\\n\\n", ...generated..., ]``. The + ``skip_special_tokens=True`` detokenization strips the 128xxx + control tokens but keeps the word tokens ``"assistant"`` (78191) + and ``"\\n\\n"`` (271), so the final text looks like + ``assistant\\n\\n{"name": ...}``. The ``llama`` tool parser's + ``serde_json::from_str`` can't handle leading non-JSON prefix and + silently returns zero tool calls. + 2. The trailing stop token (e.g. ``<|eom_id|>`` = 128008) is included + in ``output_ids``; SGLang excludes it. If the gateway ever runs + with ``skip_special_tokens=False`` the stop leaks into the decoded + text and breaks JSON parsing for the same reason. + + Slicing the last ``meta_info.completion_tokens`` tokens gives us the + bare generated sequence that SGLang's ``token_ids`` would carry, and + we then defensively drop any trailing matched stop token. The + per-choice ``matched_stop`` fires in a separate proto field, so no + information is lost. + """ + raw = list(output.get("output_ids") or []) + if not raw: + return raw + completion = output.get("meta_info", {}).get("completion_tokens") + if isinstance(completion, int) and 0 < completion <= len(raw): + token_ids = raw[-completion:] + else: + token_ids = raw + if not no_stop_trim and reason_dict and reason_dict.get("type") == "stop": + matched = reason_dict.get("matched") + if isinstance(matched, int) and token_ids and token_ids[-1] == matched: + token_ids = token_ids[:-1] + return token_ids + + def _chunk_response( + self, + rid: str, + output: dict, + reason_dict: dict | None, + choice_index: int = 0, + *, + no_stop_trim: bool = False, + ) -> tokenspeed_scheduler_pb2.GenerateResponse: + meta = output.get("meta_info", {}) + token_ids = self._generated_output_ids(output, reason_dict, no_stop_trim=no_stop_trim) + return tokenspeed_scheduler_pb2.GenerateResponse( + request_id=rid, + chunk=tokenspeed_scheduler_pb2.GenerateStreamChunk( + token_ids=token_ids, + prompt_tokens=int(meta.get("prompt_tokens", 0)), + completion_tokens=int(meta.get("completion_tokens", len(token_ids))), + cached_tokens=int(meta.get("cached_tokens", 0)), + output_logprobs=self._convert_output_logprobs_to_proto(output, len(token_ids)), + index=choice_index, + ), + ) + + def _complete_response( + self, + rid: str, + output: dict, + reason_dict: dict | None, + choice_index: int = 0, + *, + no_stop_trim: bool = False, + ) -> tokenspeed_scheduler_pb2.GenerateResponse: + meta = output.get("meta_info", {}) + token_ids = self._generated_output_ids(output, reason_dict, no_stop_trim=no_stop_trim) + + finish_reason = "stop" + matched_kwargs: dict[str, Any] = {} + if reason_dict: + kind = reason_dict.get("type") + if kind == "length": + finish_reason = "length" + elif kind == "abort": + finish_reason = "abort" + matched = reason_dict.get("matched") + if isinstance(matched, int): + matched_kwargs["matched_token_id"] = matched + elif isinstance(matched, str): + matched_kwargs["matched_stop_str"] = matched + + return tokenspeed_scheduler_pb2.GenerateResponse( + request_id=rid, + complete=tokenspeed_scheduler_pb2.GenerateComplete( + output_ids=token_ids, + finish_reason=finish_reason, + prompt_tokens=int(meta.get("prompt_tokens", 0)), + completion_tokens=int(meta.get("completion_tokens", len(token_ids))), + cached_tokens=int(meta.get("cached_tokens", 0)), + output_logprobs=self._convert_output_logprobs_to_proto(output, len(token_ids)), + 
index=choice_index, + **matched_kwargs, + ), + ) + + @staticmethod + def _convert_output_logprobs_to_proto( + output: dict, n_keep: int + ) -> tokenspeed_scheduler_pb2.OutputLogProbs | None: + """Build an ``OutputLogProbs`` proto from a tokenspeed output dict. + + TokenSpeed accumulates the request's logprobs in per-request state + across chunks; ``meta_info["output_token_logprobs"]`` is therefore the + running cumulative list of detokenized + ``(logprob: float, token_id: int, text: Optional[str])`` tuples, and + ``meta_info["output_top_logprobs"]`` is the parallel list of top-K + alternatives per position (each entry is ``None`` or a list of the + same tuple shape). + + We slice the cumulative list down to just **this frame's tokens** by + taking the last ``len(output["output_ids"])`` entries — that's how + many new tokens this frame emitted — and then keep only the first + ``n_keep`` of those, so the alignment matches whatever + ``_generated_output_ids`` returned (it strips a trailing stop token + when the finish reason is ``stop``, leaving the last logprob entry + with no corresponding output id). + + Returns ``None`` when there are no logprobs to emit — either the + client did not request them, or the server was started without + ``--enable-output-logprobs`` (in which case TokenSpeed silently + leaves these meta_info lists empty rather than raising). + """ + if n_keep <= 0: + return None + meta = output.get("meta_info", {}) or {} + raw_token = meta.get("output_token_logprobs") or [] + if not raw_token: + return None + n_chunk = len(output.get("output_ids", []) or []) + if n_chunk <= 0: + return None + + raw_top = meta.get("output_top_logprobs") or [] + chunk_token = raw_token[-n_chunk:] if len(raw_token) >= n_chunk else raw_token + chunk_top = raw_top[-n_chunk:] if len(raw_top) >= n_chunk else raw_top + delta_token = chunk_token[:n_keep] + delta_top = chunk_top[:n_keep] + + top_proto = [] + for entry in delta_top: + if entry: + top_proto.append( + tokenspeed_scheduler_pb2.TopLogProbs( + values=[t[0] for t in entry], + token_ids=[t[1] for t in entry], + ) + ) + else: + # Position with no top-K data (e.g. ``--enable-top-logprobs`` + # is not yet implemented in TokenSpeed; we still emit a + # placeholder per position so the gateway can align indices). 
+ top_proto.append(tokenspeed_scheduler_pb2.TopLogProbs()) + + return tokenspeed_scheduler_pb2.OutputLogProbs( + token_logprobs=[t[0] for t in delta_token], + token_ids=[t[1] for t in delta_token], + top_logprobs=top_proto, + ) + + +def _abort_status_code(reason: dict) -> grpc.StatusCode: + status_code = reason.get("status_code") + if status_code == 400: + return grpc.StatusCode.INVALID_ARGUMENT + if status_code in (408, 504): + return grpc.StatusCode.DEADLINE_EXCEEDED + if status_code == 429: + return grpc.StatusCode.RESOURCE_EXHAUSTED + return grpc.StatusCode.INTERNAL + + +def _make_json_serializable(obj: Any) -> Any: + """Flatten an arbitrary dataclass/config graph into JSON-safe primitives.""" + if obj is None or isinstance(obj, str | int | float | bool): + return obj + if isinstance(obj, list | tuple | set): + return [_make_json_serializable(x) for x in obj] + if isinstance(obj, dict): + return {str(k): _make_json_serializable(v) for k, v in obj.items()} + return str(obj) diff --git a/grpc_servicer/tests/__init__.py b/grpc_servicer/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/grpc_servicer/tests/conftest.py b/grpc_servicer/tests/conftest.py new file mode 100644 index 000000000..3ceadba4f --- /dev/null +++ b/grpc_servicer/tests/conftest.py @@ -0,0 +1,22 @@ +"""Pytest configuration for smg-grpc-servicer unit tests. + +Adds the parent directory to ``sys.path`` so editable installs work +without needing ``pip install -e``, and declares an asyncio-mode default. +""" + +from __future__ import annotations + +import pathlib +import sys + +import pytest + +_HERE = pathlib.Path(__file__).resolve().parent +_PKG_ROOT = _HERE.parent + +if str(_PKG_ROOT) not in sys.path: + sys.path.insert(0, str(_PKG_ROOT)) + + +def pytest_configure(config: pytest.Config) -> None: + config.addinivalue_line("markers", "tokenspeed: tests that require TokenSpeed") diff --git a/grpc_servicer/tests/test_tokenspeed_health_servicer.py b/grpc_servicer/tests/test_tokenspeed_health_servicer.py new file mode 100644 index 000000000..df4856af1 --- /dev/null +++ b/grpc_servicer/tests/test_tokenspeed_health_servicer.py @@ -0,0 +1,98 @@ +"""Unit tests for ``smg_grpc_servicer.tokenspeed.health_servicer``.""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import Any +from unittest.mock import MagicMock + +import grpc +import pytest +from grpc_health.v1 import health_pb2 # noqa: E402 +from smg_grpc_servicer.tokenspeed.health_servicer import ( # noqa: E402 + TokenSpeedHealthServicer, +) + + +@dataclass +class FakeEngine: + gracefully_exit: bool = False + last_receive_tstamp: float = 0.0 + rid_to_state: dict[str, Any] = field(default_factory=dict) + + +@pytest.fixture +def servicer() -> TokenSpeedHealthServicer: + return TokenSpeedHealthServicer( + async_llm=FakeEngine(), + scheduler_info={}, + ) + + +@pytest.mark.asyncio +async def test_initial_state_is_not_serving(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + resp = await servicer.Check(health_pb2.HealthCheckRequest(service=""), ctx) + assert resp.status == health_pb2.HealthCheckResponse.NOT_SERVING + + +@pytest.mark.asyncio +async def test_set_serving_flips_both_levels(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + servicer.set_serving() + + # overall + resp = await servicer.Check(health_pb2.HealthCheckRequest(service=""), ctx) + assert resp.status == health_pb2.HealthCheckResponse.SERVING + + # specific + 
resp = await servicer.Check( + health_pb2.HealthCheckRequest(service=servicer.TOKENSPEED_SERVICE), ctx + ) + assert resp.status == health_pb2.HealthCheckResponse.SERVING + + +@pytest.mark.asyncio +async def test_shutdown_flips_back(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + servicer.set_serving() + servicer.async_llm.gracefully_exit = True + resp = await servicer.Check(health_pb2.HealthCheckRequest(service=""), ctx) + assert resp.status == health_pb2.HealthCheckResponse.NOT_SERVING + + +@pytest.mark.asyncio +async def test_unknown_service_returns_unknown(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + resp = await servicer.Check(health_pb2.HealthCheckRequest(service="bogus.Service"), ctx) + assert resp.status == health_pb2.HealthCheckResponse.SERVICE_UNKNOWN + ctx.set_code.assert_called_once_with(grpc.StatusCode.NOT_FOUND) + + +@pytest.mark.asyncio +async def test_stuck_scheduler_flips_to_not_serving( + servicer: TokenSpeedHealthServicer, +): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + servicer.set_serving() + # Simulate "pending requests, but scheduler hasn't pushed output for 45s" + servicer.async_llm.last_receive_tstamp = time.time() - 45 + servicer.async_llm.rid_to_state["rid-1"] = object() + + resp = await servicer.Check( + health_pb2.HealthCheckRequest(service=servicer.TOKENSPEED_SERVICE), ctx + ) + assert resp.status == health_pb2.HealthCheckResponse.NOT_SERVING + + +@pytest.mark.asyncio +async def test_recent_activity_keeps_serving(servicer: TokenSpeedHealthServicer): + ctx = MagicMock(spec=grpc.aio.ServicerContext) + servicer.set_serving() + servicer.async_llm.last_receive_tstamp = time.time() - 1 + servicer.async_llm.rid_to_state["rid-1"] = object() + resp = await servicer.Check( + health_pb2.HealthCheckRequest(service=servicer.TOKENSPEED_SERVICE), ctx + ) + assert resp.status == health_pb2.HealthCheckResponse.SERVING diff --git a/grpc_servicer/tests/test_tokenspeed_servicer.py b/grpc_servicer/tests/test_tokenspeed_servicer.py new file mode 100644 index 000000000..89ed5c549 --- /dev/null +++ b/grpc_servicer/tests/test_tokenspeed_servicer.py @@ -0,0 +1,1103 @@ +"""Unit tests for ``smg_grpc_servicer.tokenspeed.servicer``. + +Runs against a minimal ``FakeAsyncLLM`` that implements only the AsyncLLM +surface the servicer actually touches. We *do* require TokenSpeed to be +importable (the servicer takes real request classes from ``tokenspeed.*``), +so the whole module is skipped when TokenSpeed is not installed. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Callable +from dataclasses import dataclass, field +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import grpc +import pytest + +pytest.importorskip( + "smg_grpc_proto", + reason="smg-grpc-proto must be installed to test the servicer", +) + +from smg_grpc_proto.generated import tokenspeed_scheduler_pb2 # noqa: E402 +from smg_grpc_servicer.tokenspeed import servicer as _servicer_module # noqa: E402 +from smg_grpc_servicer.tokenspeed.servicer import ( # noqa: E402 + TokenSpeedSchedulerServicer, + _abort_status_code, + _finish_reason_to_dict, + _make_json_serializable, +) + +# --------------------------------------------------------------------------- +# Stub request class. The servicer lazily imports ``GenerateReqInput`` so +# tests can substitute a minimal local stand-in without pulling in +# TokenSpeed's full scheduler graph. 
(No ``EmbeddingReqInput`` — the slim +# TokenSpeed proto removed the Embed RPC.) +# --------------------------------------------------------------------------- + + +class _StubReq: + """Minimal stand-in with the attributes the servicer sets on req objects.""" + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + # Allow later attribute assignment for rid / text. + self.rid = None + self.text = None + + +class StubGenerateReqInput(_StubReq): + pass + + +@pytest.fixture(autouse=True) +def _stub_request_inputs(monkeypatch): + """Redirect the servicer's lazy GenerateReqInput import to a local stub.""" + monkeypatch.setattr(_servicer_module, "_lazy_generate_req_input", lambda: StubGenerateReqInput) + yield + + +# --------------------------------------------------------------------------- +# Local fake finish-reason classes. The servicer duck-types on ``.to_json()`` +# so tests don't need to import TokenSpeed's request_types module (which +# pulls in the full scheduler graph and breaks in minimal test envs). +# --------------------------------------------------------------------------- + + +class FINISH_MATCHED_TOKEN: + def __init__(self, matched): + self.matched = matched + + def to_json(self): + return {"type": "stop", "matched": self.matched} + + +class FINISH_MATCHED_STR: + def __init__(self, matched): + self.matched = matched + + def to_json(self): + return {"type": "stop", "matched": self.matched} + + +class FINISH_LENGTH: + def __init__(self, length): + self.length = length + + def to_json(self): + return {"type": "length", "length": self.length} + + +class FINISH_ABORT: + def __init__(self, message="Unknown error"): + self.message = message + + def to_json(self): + return {"type": "abort", "message": self.message} + + +# --------------------------------------------------------------------------- +# FakeAsyncLLM — minimal stand-in for TokenSpeed's AsyncLLM in unit tests. +# --------------------------------------------------------------------------- + + +@dataclass +class _FakeState: + finished: bool = False + + +@dataclass +class FakeAsyncLLM: + """Implements just enough AsyncLLM surface to drive the servicer.""" + + outputs: list[dict] = field(default_factory=list) + is_generation: bool = True + context_len: int = 8192 + max_req_input_len: int | None = 4096 + # Captured state — the servicer mutates/inspects these. + rid_to_state: dict[str, _FakeState] = field(default_factory=dict) + gracefully_exit: bool = False + last_receive_tstamp: float = 0.0 + handle_loop_started: bool = False + aborted_rids: list[str] = field(default_factory=list) + # Override hook: a callable producing outputs per request, used for + # tests that need dynamic yields (e.g. cancel mid-stream). + generate_fn: Callable[[Any], Any] | None = None + + # Default load-fixture: single DP rank, 1 running request, no waiting, + # 100 used pages out of (max_total_num_tokens / page_size). Tests can + # override ``load_outputs`` directly to assert proto-mapping semantics. 
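+    # (The dataclass default itself is an empty list; GetLoads tests assign
+    # the per-rank entries they need.)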
+ load_outputs: list[Any] = field(default_factory=list) + max_total_num_tokens: int = 8192 + + server_args: Any = field( + default_factory=lambda: SimpleNamespace( + model_path="fake-model", + tokenizer_path="fake-model", + served_model_name="fake-model", + preferred_sampling_params=None, + page_size=16, + ) + ) + model_config: Any = field( + default_factory=lambda: SimpleNamespace( + vocab_size=32000, + is_multimodal=False, + hf_config=SimpleNamespace( + eos_token_id=2, + pad_token_id=0, + bos_token_id=1, + model_type="llama", + architectures=["LlamaForCausalLM"], + ), + ) + ) + + def auto_create_handle_loop(self) -> None: + self.handle_loop_started = True + + def abort_request(self, rid: str) -> None: + self.aborted_rids.append(rid) + self.rid_to_state.pop(rid, None) + + async def get_load(self): + # Mirror SchedulerControlClient.get_load — returns the configured + # ``load_outputs`` so tests can drive proto-mapping assertions. + return list(self.load_outputs) + + async def generate_request(self, obj): + # Record the request so tests can assert on what was forwarded. + # ``_build_generate_req`` rewrites ``rid`` to a list of per-choice ids + # when n>1; register state for each so the cancel sweep can abort them + # individually (and so dict assignment doesn't crash on a list key). + rid_attr = getattr(obj, "rid", None) or "no-rid" + rids = list(rid_attr) if isinstance(rid_attr, list) else [rid_attr] + for r in rids: + self.rid_to_state[r] = _FakeState() + if self.generate_fn is not None: + async for out in self.generate_fn(obj): + self.last_receive_tstamp = 9999.0 # anything > tic + yield out + return + for out in self.outputs: + self.last_receive_tstamp = 9999.0 + yield out + for r in rids: + self.rid_to_state[r].finished = True + + +@pytest.fixture +def fake_engine() -> FakeAsyncLLM: + return FakeAsyncLLM() + + +@pytest.fixture +def servicer(fake_engine: FakeAsyncLLM) -> TokenSpeedSchedulerServicer: + return TokenSpeedSchedulerServicer( + async_llm=fake_engine, + server_args=fake_engine.server_args, + scheduler_info={ + "max_total_num_tokens": 100000, + "max_req_input_len": 4096, + }, + ) + + +class _FakeAbortError(grpc.aio.AbortError): + """Stand-in for grpc.aio.AbortError raised by our mock context.abort().""" + + def __init__(self, code: grpc.StatusCode, details: str): + super().__init__() + self.code = code + self.details = details + + def __str__(self) -> str: # makes pytest.raises(match=...) useful + return f"ABORT({self.code.name}, {self.details})" + + +def _make_context() -> MagicMock: + """Build a grpc.aio.ServicerContext whose ``abort()`` raises AbortError. + + Real gRPC servicer contexts raise ``grpc.aio.AbortError`` from + ``context.abort()``. The servicer has a dedicated ``except + grpc.aio.AbortError: raise`` branch to let that propagate cleanly, so + the mock reproduces that behaviour. 
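+
+    Typical use (sketch): a test passes ``_make_context()`` to a servicer
+    method and either asserts ``ctx.abort`` was awaited with the expected
+    status code or catches the raised ``_FakeAbortError``.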
+ """ + ctx = MagicMock(spec=grpc.aio.ServicerContext) + + async def _abort(code, details): + raise _FakeAbortError(code, details) + + ctx.abort = AsyncMock(side_effect=_abort) + ctx.set_code = MagicMock() + ctx.set_details = MagicMock() + return ctx + + +# --------------------------------------------------------------------------- +# Pure-helper tests +# --------------------------------------------------------------------------- + + +class TestFinishReasonToDict: + def test_none(self): + assert _finish_reason_to_dict(None) is None + + def test_length(self): + assert _finish_reason_to_dict(FINISH_LENGTH(length=42)) == { + "type": "length", + "length": 42, + } + + def test_matched_token(self): + assert _finish_reason_to_dict(FINISH_MATCHED_TOKEN(matched=7)) == { + "type": "stop", + "matched": 7, + } + + def test_matched_str(self): + assert _finish_reason_to_dict(FINISH_MATCHED_STR(matched="")) == { + "type": "stop", + "matched": "", + } + + def test_abort(self): + out = _finish_reason_to_dict(FINISH_ABORT(message="boom")) + assert out["type"] == "abort" + assert out["message"] == "boom" + + def test_passthrough_dict(self): + d = {"type": "stop", "matched": "foo"} + assert _finish_reason_to_dict(d) is d + + def test_unknown_raises_typeerror(self): + # Unknown shapes raise TypeError rather than coercing to a fake + # ``stop`` dict: silently flipping length/abort to stop and leaking + # repr() into the user-facing matched_stop_str field would corrupt + # the OpenAI ``finish_reason`` semantics. The Generate handler's + # ``except Exception`` turns the TypeError into INTERNAL. + with pytest.raises(TypeError, match="Unknown finish_reason shape"): + _finish_reason_to_dict("weird") + with pytest.raises(TypeError, match="Unknown finish_reason shape"): + _finish_reason_to_dict(42) + + +class TestAbortStatusCode: + @pytest.mark.parametrize( + "status_code, expected", + [ + (400, grpc.StatusCode.INVALID_ARGUMENT), + (408, grpc.StatusCode.DEADLINE_EXCEEDED), + (504, grpc.StatusCode.DEADLINE_EXCEEDED), + (429, grpc.StatusCode.RESOURCE_EXHAUSTED), + (500, grpc.StatusCode.INTERNAL), + (None, grpc.StatusCode.INTERNAL), + ], + ) + def test_mapping(self, status_code, expected): + assert _abort_status_code({"status_code": status_code}) == expected + + +class TestMakeJsonSerializable: + def test_primitives(self): + assert _make_json_serializable(1) == 1 + assert _make_json_serializable("x") == "x" + assert _make_json_serializable(True) is True + assert _make_json_serializable(None) is None + + def test_list_tuple_set(self): + assert _make_json_serializable([1, "a"]) == [1, "a"] + assert _make_json_serializable((1, 2)) == [1, 2] + assert _make_json_serializable({1, 2, 3}) in ( + [1, 2, 3], + [1, 3, 2], + [2, 1, 3], + [2, 3, 1], + [3, 1, 2], + [3, 2, 1], + ) + + def test_nested_dict(self): + assert _make_json_serializable({"a": [1, {"b": 2}]}) == {"a": [1, {"b": 2}]} + + def test_exotic_types_coerced_to_str(self): + class Foo: + def __str__(self): + return "foo-str" + + assert _make_json_serializable(Foo()) == "foo-str" + + +# --------------------------------------------------------------------------- +# Sampling params conversion +# --------------------------------------------------------------------------- + + +class TestSamplingParamsConversion: + def test_defaults_not_forwarded(self): + params = tokenspeed_scheduler_pb2.SamplingParams() + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params) + # proto3 defaults (0 / False / "") should not end up as TokenSpeed + # overrides — only the 
always-forwarded bool fields appear. + assert "temperature" not in out + assert "top_p" not in out + assert "top_k" not in out + assert "max_new_tokens" not in out + # always-forwarded bools + assert out["skip_special_tokens"] is False + assert out["spaces_between_special_tokens"] is False + assert out["ignore_eos"] is False + + def test_numeric_fields_forwarded(self): + params = tokenspeed_scheduler_pb2.SamplingParams( + temperature=0.7, + top_p=0.9, + top_k=50, + min_p=0.05, + frequency_penalty=0.1, + presence_penalty=0.2, + repetition_penalty=1.1, + max_new_tokens=128, + min_new_tokens=4, + ) + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params) + assert out["temperature"] == pytest.approx(0.7) + assert out["top_p"] == pytest.approx(0.9) + assert out["top_k"] == 50 + assert out["min_p"] == pytest.approx(0.05) + assert out["frequency_penalty"] == pytest.approx(0.1) + assert out["presence_penalty"] == pytest.approx(0.2) + assert out["repetition_penalty"] == pytest.approx(1.1) + assert out["max_new_tokens"] == 128 + assert out["min_new_tokens"] == 4 + + def test_stop_lists_and_logit_bias(self): + params = tokenspeed_scheduler_pb2.SamplingParams( + stop=["\n\n", ""], + stop_token_ids=[2, 0], + logit_bias={"100": -10.0, "200": 10.0}, + ) + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params) + assert out["stop"] == ["\n\n", ""] + assert out["stop_token_ids"] == [2, 0] + assert out["logit_bias"] == {"100": -10.0, "200": 10.0} + + @pytest.mark.parametrize( + "setter, key, value", + [ + (lambda p: setattr(p, "regex", "a.*"), "regex", "a.*"), + (lambda p: setattr(p, "json_schema", "{}"), "json_schema", "{}"), + (lambda p: setattr(p, "ebnf_grammar", "g"), "ebnf", "g"), + (lambda p: setattr(p, "structural_tag", "tag"), "structural_tag", "tag"), + ], + ) + def test_constraints(self, setter, key, value): + params = tokenspeed_scheduler_pb2.SamplingParams() + setter(params) + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params) + assert out[key] == value + + def test_json_schema_no_reasoning_parser_passes_through(self): + params = tokenspeed_scheduler_pb2.SamplingParams(json_schema='{"type": "object"}') + out = TokenSpeedSchedulerServicer._sampling_params_from_proto(params, reasoning_parser=None) + assert out["json_schema"] == '{"type": "object"}' + assert "structural_tag" not in out + + def test_json_schema_with_reasoning_parser_wraps_as_structural_tag(self, monkeypatch): + import sys + import types + + fake_module = types.ModuleType("tokenspeed.runtime.grammar.reasoning_structural_tag") + captured: dict[str, Any] = {} + + def _fake_wrap(rp: str, schema: Any) -> str: + captured["rp"] = rp + captured["schema"] = schema + return '{"wrapped": "tag"}' + + fake_module.structural_tag_for_reasoning_json_schema = _fake_wrap + monkeypatch.setitem( + sys.modules, + "tokenspeed.runtime.grammar.reasoning_structural_tag", + fake_module, + ) + + params = tokenspeed_scheduler_pb2.SamplingParams(json_schema='{"type": "object"}') + out = TokenSpeedSchedulerServicer._sampling_params_from_proto( + params, reasoning_parser="gpt-oss" + ) + + assert "json_schema" not in out + assert out["structural_tag"] == '{"wrapped": "tag"}' + assert captured["rp"] == "gpt-oss" + assert captured["schema"] == {"type": "object"} + + def test_json_schema_unknown_parser_falls_back_to_raw(self, monkeypatch): + import sys + import types + + fake_module = types.ModuleType("tokenspeed.runtime.grammar.reasoning_structural_tag") + fake_module.structural_tag_for_reasoning_json_schema = lambda 
rp, s: None + monkeypatch.setitem( + sys.modules, + "tokenspeed.runtime.grammar.reasoning_structural_tag", + fake_module, + ) + + params = tokenspeed_scheduler_pb2.SamplingParams(json_schema='{"type": "object"}') + out = TokenSpeedSchedulerServicer._sampling_params_from_proto( + params, reasoning_parser="unknown-parser" + ) + + assert out["json_schema"] == '{"type": "object"}' + assert "structural_tag" not in out + + +# --------------------------------------------------------------------------- +# Generate RPC +# --------------------------------------------------------------------------- + + +def _make_generate_request( + *, + request_id: str = "rid-1", + input_ids: list[int] | None = None, + stream: bool = False, + max_new_tokens: int = 16, +) -> tokenspeed_scheduler_pb2.GenerateRequest: + return tokenspeed_scheduler_pb2.GenerateRequest( + request_id=request_id, + tokenized=tokenspeed_scheduler_pb2.TokenizedInput( + # Preserve explicit empty-list inputs (for "rejects empty ids" test); + # only fall back to the default if the caller didn't supply any. + input_ids=(input_ids if input_ids is not None else [1, 2, 3, 4]), + original_text="hello", + ), + sampling_params=tokenspeed_scheduler_pb2.SamplingParams( + temperature=0.0, + max_new_tokens=max_new_tokens, + ), + stream=stream, + ) + + +class TestGenerate: + @pytest.mark.asyncio + async def test_non_streaming_emits_complete( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + # TokenSpeed's AsyncLLM includes the trailing matched-stop token in + # ``output_ids`` (and prepends chat-template header tokens — modeled in + # ``test_strips_chat_template_prefix`` below). The servicer normalizes + # these out before the proto goes to the smg gateway so the tool + # parsers see the same tokens they would from the SGLang path. Here we + # check the matched-stop trim: ``raw=[10,11,12]`` with ``matched=12`` + # should arrive as ``[10,11]`` on the wire, and the matched id still + # rides in the ``matched_token_id`` field. + fake_engine.outputs = [ + { + "text": "hi", + "output_ids": [10, 11, 12], + "meta_info": { + "prompt_tokens": 4, + "completion_tokens": 3, + "cached_tokens": 0, + "finish_reason": FINISH_MATCHED_TOKEN(matched=12), + }, + } + ] + ctx = _make_context() + req = _make_generate_request(stream=False) + + frames = [frame async for frame in servicer.Generate(req, ctx)] + assert len(frames) == 1 + frame = frames[0] + assert frame.request_id == "rid-1" + assert frame.HasField("complete") + complete = frame.complete + assert list(complete.output_ids) == [10, 11] + assert complete.finish_reason == "stop" + assert complete.matched_token_id == 12 + assert complete.prompt_tokens == 4 + # Meta's completion_tokens passes through unchanged — matches SGLang's + # ``meta_info.get("completion_tokens")`` convention — even though the + # on-the-wire ``output_ids`` drops the stop token. + assert complete.completion_tokens == 3 + ctx.abort.assert_not_called() + + @pytest.mark.asyncio + async def test_strips_chat_template_prefix( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + """Reproducer for the bug where ``assistant\\n\\n`` leaked into the + decoded text and broke the ``llama`` tool-call parser. + + Real-world capture on Llama-3.2-1B-Instruct with a function-calling + prompt — ``output_ids`` was 27 tokens: 5 chat-template header tokens + (``<|eot_id|>, <|start_header_id|>, "assistant", <|end_header_id|>, + "\\n\\n"``) + 21 generated JSON tokens + 1 ``<|eom_id|>`` stop. 
With + ``skip_special_tokens=True`` only the 128xxx control tokens get + stripped at detokenization time, so the word token ``"assistant"`` + (78191) and ``"\\n\\n"`` (271) leaked into the text and flipped + ``serde_json::from_str`` from succeeding on clean JSON to failing on + ``assistant\\n\\n{...}``. + + The servicer now slices to the last ``completion_tokens`` tokens so + downstream detokenization only sees the actual generated content. + """ + fake_engine.outputs = [ + { + "text": '{"name": "add", "parameters": {"a": 3, "b": 5}}', + # Shape observed in the wild: [<|eot|>, <|start|>, "assistant", + # <|end|>, "\n\n", ...21 json tokens, <|eom|>] = 27 tokens. + # ``completion_tokens`` in TokenSpeed's meta covers the content + # *plus* the stop token, so 21 + 1 = 22. + "output_ids": [ + 128009, + 128006, + 78191, + 128007, + 271, + *range(9000, 9021), + 128008, + ], + "meta_info": { + "prompt_tokens": 200, + "completion_tokens": 22, + "cached_tokens": 0, + "finish_reason": FINISH_MATCHED_TOKEN(matched=128008), + }, + } + ] + ctx = _make_context() + req = _make_generate_request(stream=False) + + frames = [frame async for frame in servicer.Generate(req, ctx)] + complete = frames[0].complete + # Header tokens dropped via the ``raw[-completion_tokens:]`` slice; + # trailing stop token dropped because ``matched == token_ids[-1]``. + assert list(complete.output_ids) == list(range(9000, 9021)) + assert complete.matched_token_id == 128008 + # meta_info.completion_tokens passes through; only ``output_ids`` is + # normalized. Keeps the tokenspeed servicer's wire contract aligned + # with the SGLang reference. + assert complete.completion_tokens == 22 + + @pytest.mark.asyncio + async def test_streaming_emits_chunks_then_complete( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + fake_engine.outputs = [ + { + "text": "hi", + "output_ids": [10], # delta chunk 1 + "meta_info": { + "prompt_tokens": 4, + "completion_tokens": 1, + "cached_tokens": 0, + "finish_reason": None, + }, + }, + { + "text": "hi there", + "output_ids": [11, 12], # delta chunk 2 + finish + "meta_info": { + "prompt_tokens": 4, + "completion_tokens": 3, + "cached_tokens": 0, + "finish_reason": FINISH_LENGTH(length=16), + }, + }, + ] + ctx = _make_context() + req = _make_generate_request(stream=True) + + frames = [frame async for frame in servicer.Generate(req, ctx)] + # Expect: 2 chunks + 1 complete (emitted alongside the final chunk). + # ``completion_tokens`` here (3) exceeds this chunk's delta length (2), + # so the slice falls back to the raw delta. Length-finish has no + # matched stop to strip either, so token_ids pass through. 
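A minimal sketch of the normalization these Generate tests pin down — the helper name is hypothetical and the servicer's actual implementation is not part of this diff; only the slice-then-strip behaviour the assertions encode is assumed:

```python
def _normalize_output_ids(raw_ids, completion_tokens, matched_token_id=None):
    """Sketch: keep the last ``completion_tokens`` ids (drops chat-template
    header tokens), fall back to the raw frame when it carries fewer ids
    than ``completion_tokens`` (a streaming delta), then drop a trailing
    matched stop token."""
    ids = list(raw_ids)
    if completion_tokens and completion_tokens <= len(ids):
        ids = ids[-completion_tokens:]
    if matched_token_id is not None and ids and ids[-1] == matched_token_id:
        ids = ids[:-1]
    return ids

# Worked against the fixtures in these tests:
assert _normalize_output_ids([10, 11, 12], 3, matched_token_id=12) == [10, 11]
assert _normalize_output_ids(
    [128009, 128006, 78191, 128007, 271, *range(9000, 9021), 128008],
    22,
    matched_token_id=128008,
) == list(range(9000, 9021))
assert _normalize_output_ids([11, 12], 3) == [11, 12]  # streaming-delta fallback
```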
+ assert len(frames) == 3 + assert frames[0].HasField("chunk") + assert list(frames[0].chunk.token_ids) == [10] + assert frames[1].HasField("chunk") + assert list(frames[1].chunk.token_ids) == [11, 12] + assert frames[2].HasField("complete") + assert frames[2].complete.finish_reason == "length" + assert list(frames[2].complete.output_ids) == [11, 12] + + @pytest.mark.asyncio + async def test_empty_input_ids_rejected( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + ctx = _make_context() + req = _make_generate_request(input_ids=[]) + + with pytest.raises(_FakeAbortError) as exc: + async for _ in servicer.Generate(req, ctx): + pass + assert exc.value.code == grpc.StatusCode.INVALID_ARGUMENT + ctx.abort.assert_awaited_once() + + @pytest.mark.asyncio + async def test_abort_finish_reason_surfaces_as_grpc_error( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + fake_engine.outputs = [ + { + "text": "", + "output_ids": [], + "meta_info": { + "prompt_tokens": 0, + "completion_tokens": 0, + "cached_tokens": 0, + "finish_reason": { + "type": "abort", + "message": "client disconnected", + "status_code": 400, + }, + }, + } + ] + ctx = _make_context() + req = _make_generate_request() + + with pytest.raises(_FakeAbortError) as exc: + async for _ in servicer.Generate(req, ctx): + pass + assert exc.value.code == grpc.StatusCode.INVALID_ARGUMENT + + @pytest.mark.asyncio + async def test_cancel_calls_abort_request( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + """Cancelling the Generate task should tell the scheduler to drop the rid.""" + + started = asyncio.Event() + + async def never_finish(_obj): + started.set() + # Block forever so we can cancel from outside. ``yield`` is + # unreachable but keeps this an async generator. + await asyncio.sleep(30) + yield {} # pragma: no cover + + fake_engine.generate_fn = never_finish + ctx = _make_context() + req = _make_generate_request() + + gen = servicer.Generate(req, ctx) + task = asyncio.create_task(_drain(gen)) + await started.wait() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + assert "rid-1" in fake_engine.aborted_rids + + @pytest.mark.asyncio + async def test_cancel_aborts_all_n_children( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + """n>1 expands rid to a list of per-choice ids; cancel must sweep them all. + + _build_generate_req rewrites ``rid`` to ``[rid-n0, rid-n1, ...]`` so + TokenSpeed's batch path sees unique rids per choice. If Generate's + cancel handler aborts only the original rid, the child scheduler + requests keep consuming GPU work. This test guards that edge. + """ + started = asyncio.Event() + + async def never_finish(_obj): + started.set() + await asyncio.sleep(30) + yield {} # pragma: no cover + + fake_engine.generate_fn = never_finish + ctx = _make_context() + req = _make_generate_request() + req.sampling_params.n = 3 + + gen = servicer.Generate(req, ctx) + task = asyncio.create_task(_drain(gen)) + await started.wait() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + # Every per-choice rid must have had abort_request called. 
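The per-choice rid scheme that the next assertion (and TestAbortRpc below) relies on, as a sketch — helper names are hypothetical; only the ``{rid}-n{i}`` naming and the prefix sweep recorded by the fake engine are assumed:

```python
def _expand_rid(rid: str, n: int) -> list[str]:
    # One unique scheduler rid per choice when n > 1, mirroring what the
    # tests say _build_generate_req does.
    return [rid] if n <= 1 else [f"{rid}-n{i}" for i in range(n)]

def _sweep_abort(engine, rid: str) -> bool:
    # Abort the rid itself plus every registered "{rid}-n*" child; leave
    # unrelated rids alone.
    targets = [r for r in list(engine.rid_to_state) if r == rid or r.startswith(f"{rid}-n")]
    for r in targets:
        engine.abort_request(r)
    return bool(targets)

assert _expand_rid("rid-1", 3) == ["rid-1-n0", "rid-1-n1", "rid-1-n2"]
```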
+ assert set(fake_engine.aborted_rids) >= {"rid-1-n0", "rid-1-n1", "rid-1-n2"} + + +async def _drain(async_gen): + async for _ in async_gen: + pass + + +# --------------------------------------------------------------------------- +# Embed RPC +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# Abort / HealthCheck / GetModelInfo / GetServerInfo / GetLoads +# +# Note: TokenSpeed's slim proto removes Embed / GetTokenizer / SubscribeKvEvents +# entirely, so there are no tests for them — the methods aren't on the +# servicer surface. +# --------------------------------------------------------------------------- + + +class TestAbortRpc: + @pytest.mark.asyncio + async def test_abort_known( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + fake_engine.rid_to_state["rid-1"] = _FakeState() + resp = await servicer.Abort( + tokenspeed_scheduler_pb2.AbortRequest(request_id="rid-1"), + _make_context(), + ) + assert resp.success is True + assert "rid-1" in fake_engine.aborted_rids + + @pytest.mark.asyncio + async def test_abort_unknown( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + resp = await servicer.Abort( + tokenspeed_scheduler_pb2.AbortRequest(request_id="missing"), + _make_context(), + ) + assert resp.success is False + # Nothing to abort — no state for "missing" or any "missing-n*" child. + assert fake_engine.aborted_rids == [] + + @pytest.mark.asyncio + async def test_abort_sweeps_n_children( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + """Abort("rid-1") must sweep the per-choice rids Generate mints + when ``sampling_params.n > 1`` (``rid-1-n0``, ``rid-1-n1``, ...). + """ + for child in ("rid-1-n0", "rid-1-n1", "rid-1-n2"): + fake_engine.rid_to_state[child] = _FakeState() + # An unrelated rid the sweep must NOT touch. + fake_engine.rid_to_state["unrelated-rid"] = _FakeState() + + resp = await servicer.Abort( + tokenspeed_scheduler_pb2.AbortRequest(request_id="rid-1"), + _make_context(), + ) + assert resp.success is True + assert sorted(fake_engine.aborted_rids) == [ + "rid-1-n0", + "rid-1-n1", + "rid-1-n2", + ] + + +class TestHealthCheck: + @pytest.mark.asyncio + async def test_reports_shutdown( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + fake_engine.gracefully_exit = True + resp = await servicer.HealthCheck( + tokenspeed_scheduler_pb2.HealthCheckRequest(), _make_context() + ) + assert resp.healthy is False + assert "shutting down" in resp.message.lower() + + @pytest.mark.asyncio + async def test_reports_healthy_when_scheduler_pushes_output( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + # generate_request yields once and updates last_receive_tstamp, which + # is what the health RPC watches for. 
+ fake_engine.outputs = [ + { + "text": "", + "output_ids": [99], + "meta_info": {"finish_reason": FINISH_LENGTH(length=1)}, + } + ] + resp = await servicer.HealthCheck( + tokenspeed_scheduler_pb2.HealthCheckRequest(), _make_context() + ) + assert resp.healthy is True + + +class TestGetModelInfo: + @pytest.mark.asyncio + async def test_basic_fields( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + resp = await servicer.GetModelInfo( + tokenspeed_scheduler_pb2.GetModelInfoRequest(), _make_context() + ) + assert resp.model_path == "fake-model" + assert resp.vocab_size == 32000 + assert resp.max_context_length == 8192 + assert list(resp.eos_token_ids) == [2] + assert resp.model_type == "llama" + assert list(resp.architectures) == ["LlamaForCausalLM"] + + +class TestGetServerInfo: + @pytest.mark.asyncio + async def test_returns_scheduler_info( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + fake_engine.rid_to_state["a"] = _FakeState() + fake_engine.rid_to_state["b"] = _FakeState() + resp = await servicer.GetServerInfo( + tokenspeed_scheduler_pb2.GetServerInfoRequest(), _make_context() + ) + assert resp.active_requests == 2 + assert resp.max_total_num_tokens == 100000 + assert resp.tokenspeed_version + + @pytest.mark.asyncio + async def test_uses_tokenspeed_service_bases(self, servicer: TokenSpeedSchedulerServicer): + """TokenSpeed's servicer inherits the dedicated + ``TokenSpeedSchedulerServicer`` stub — identity is carried by the + proto package/service name, not by a field inside ``server_args``. + Guard the inheritance so nobody reverts to ``SglangSchedulerServicer`` + under the impression that 'wire shape is the same'; the wire shape + is the same, the *service path* is not, and the Rust router routes + on the service path. + """ + from smg_grpc_proto.generated import tokenspeed_scheduler_pb2_grpc + + assert isinstance(servicer, tokenspeed_scheduler_pb2_grpc.TokenSpeedSchedulerServicer) + + +class TestGetLoads: + @pytest.mark.asyncio + async def test_no_dp_ranks_returns_empty( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + # Bridge returns an empty list (e.g. before scheduler boots) — proto + # comes back with 0 ranks but still validly populated for the router. + fake_engine.load_outputs = [] + resp = await servicer.GetLoads(tokenspeed_scheduler_pb2.GetLoadsRequest(), _make_context()) + assert resp.dp_rank_count == 0 + assert resp.version == "tokenspeed" + assert list(resp.loads) == [] + assert resp.aggregate.total_running_reqs == 0 + assert resp.aggregate.total_waiting_reqs == 0 + + @pytest.mark.asyncio + async def test_maps_load_output_fields( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer + ): + # 2 DP ranks. rank 0 has 3 reqs (2 running, 1 waiting) and 100 pages + # used; rank 1 has 1 reqs (1 running, 0 waiting) and 200 pages used. + # page_size=16 (from fake_engine.server_args), max_total_num_tokens=100000 + # (from the servicer fixture's scheduler_info). 
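The per-rank arithmetic the assertions further below check, worked out as a plain-Python sketch; the constants match the fixture data that follows:

```python
page_size, max_total = 16, 100_000

def rank_load(num_reqs, num_waiting_reqs, num_pages):
    used = num_pages * page_size
    return {
        "num_running_reqs": num_reqs - num_waiting_reqs,
        "num_waiting_reqs": num_waiting_reqs,
        "num_total_reqs": num_reqs,
        "num_used_tokens": used,
        "token_usage": used / max_total,
    }

r0 = rank_load(3, 1, 100)  # used = 1600, usage = 0.016
r1 = rank_load(1, 0, 200)  # used = 3200, usage = 0.032
avg_token_usage = (r0["token_usage"] + r1["token_usage"]) / 2  # 0.024
```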
+ fake_engine.load_outputs = [ + SimpleNamespace(dp_rank=0, num_reqs=3, num_waiting_reqs=1, num_pages=100), + SimpleNamespace(dp_rank=1, num_reqs=1, num_waiting_reqs=0, num_pages=200), + ] + resp = await servicer.GetLoads(tokenspeed_scheduler_pb2.GetLoadsRequest(), _make_context()) + assert resp.dp_rank_count == 2 + assert len(resp.loads) == 2 + # rank 0 + l0 = resp.loads[0] + assert l0.dp_rank == 0 + assert l0.num_running_reqs == 2 # num_reqs - num_waiting_reqs + assert l0.num_waiting_reqs == 1 + assert l0.num_total_reqs == 3 + assert l0.num_used_tokens == 100 * 16 # pages * page_size + assert l0.max_total_num_tokens == 100000 + assert l0.token_usage == pytest.approx(100 * 16 / 100000) + # rank 1 + l1 = resp.loads[1] + assert l1.dp_rank == 1 + assert l1.num_running_reqs == 1 + assert l1.num_used_tokens == 200 * 16 + # aggregate + assert resp.aggregate.total_running_reqs == 3 + assert resp.aggregate.total_waiting_reqs == 1 + assert resp.aggregate.total_reqs == 4 + assert resp.aggregate.avg_token_usage == pytest.approx( + (100 * 16 / 100000 + 200 * 16 / 100000) / 2 + ) + + @pytest.mark.asyncio + async def test_scheduler_timeout_aborts_with_deadline_exceeded( + self, fake_engine: FakeAsyncLLM, servicer: TokenSpeedSchedulerServicer, monkeypatch + ): + # If the scheduler subprocess never replies, the bridge call hangs. + # The servicer wraps it in ``asyncio.wait_for`` and aborts with + # DEADLINE_EXCEEDED rather than blocking the gRPC call indefinitely. + async def _hang(): + await asyncio.sleep(60) + return [] + + fake_engine.get_load = _hang # type: ignore[method-assign] + monkeypatch.setattr(_servicer_module, "HEALTH_CHECK_TIMEOUT", 0.05) + ctx = _make_context() + with pytest.raises(_FakeAbortError) as exc: + await servicer.GetLoads(tokenspeed_scheduler_pb2.GetLoadsRequest(), ctx) + assert exc.value.code == grpc.StatusCode.DEADLINE_EXCEEDED + + +# --------------------------------------------------------------------------- +# _build_generate_req semantics (pre-tokenized input) +# --------------------------------------------------------------------------- + + +class TestBuildGenerateReq: + def test_preserves_input_ids(self, servicer: TokenSpeedSchedulerServicer): + req = _make_generate_request(input_ids=[11, 22, 33], stream=True) + obj = servicer._build_generate_req(req) + assert obj.input_ids == [11, 22, 33] + assert obj.rid == "rid-1" + assert obj.stream is True + assert obj.sampling_params["max_new_tokens"] == 16 + + def test_rejects_missing_tokenized(self, servicer: TokenSpeedSchedulerServicer): + req = tokenspeed_scheduler_pb2.GenerateRequest(request_id="x") + with pytest.raises(ValueError, match="tokenized"): + servicer._build_generate_req(req) + + +# --------------------------------------------------------------------------- +# Output logprobs proto conversion +# --------------------------------------------------------------------------- + + +class TestConvertOutputLogprobsToProto: + """``_convert_output_logprobs_to_proto`` reads the cumulative + ``meta_info["output_token_logprobs"]`` / ``output_top_logprobs`` lists + that TokenSpeed accumulates per request, slices the last + ``len(output_ids)`` entries (the tokens this frame emitted), and keeps + the first ``n_keep`` so the result aligns with whatever + ``_generated_output_ids`` returned (which may have stripped a trailing + stop token).""" + + def test_returns_none_when_logprobs_empty(self): + # ``--enable-output-logprobs`` not set on the server → the keys exist + # in meta_info but the lists are empty. 
Must not return a half-built + # proto in this case (gateway would treat empty as "logprobs missing"). + out = { + "output_ids": [10, 20, 30], + "meta_info": {"output_token_logprobs": [], "output_top_logprobs": []}, + } + assert TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=3) is None + + def test_returns_none_when_keys_missing(self): + # Logprobs not requested at all → meta_info lacks the keys entirely. + out = {"output_ids": [10, 20, 30], "meta_info": {}} + assert TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=3) is None + + def test_returns_none_when_n_keep_zero(self): + # Stop-token strip can leave n_keep == 0 for a 1-token frame whose + # only token was the stop. Don't emit a proto with a length mismatch. + out = { + "output_ids": [99], + "meta_info": { + "output_token_logprobs": [(-0.1, 99, None)], + "output_top_logprobs": [None], + }, + } + assert TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=0) is None + + def test_non_streaming_full_output(self): + # Non-streaming: output_ids covers the entire generation; cumulative + # meta_info matches it exactly. n_keep == len(output_ids) → emit all. + out = { + "output_ids": [10, 20, 30], + "meta_info": { + "output_token_logprobs": [ + (-0.5, 10, None), + (-0.3, 20, None), + (-0.1, 30, None), + ], + "output_top_logprobs": [None, None, None], + }, + } + proto = TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=3) + assert proto is not None + assert list(proto.token_logprobs) == pytest.approx([-0.5, -0.3, -0.1]) + assert list(proto.token_ids) == [10, 20, 30] + assert len(proto.top_logprobs) == 3 + # ``None`` entries in raw_top translate to empty TopLogProbs placeholders. + for tl in proto.top_logprobs: + assert list(tl.values) == [] + assert list(tl.token_ids) == [] + + def test_streaming_chunk_emits_only_delta(self): + # Streaming chunk: output_ids has just the new tokens for this chunk, + # but meta_info is cumulative across the entire request. The slice + # ``[-len(output_ids):]`` on the cumulative list must yield exactly + # the delta this chunk represents. + out = { + "output_ids": [40, 50], # 2 new tokens this chunk + "meta_info": { + # cumulative: 4 prior tokens + 2 new + "output_token_logprobs": [ + (-1.1, 10, None), + (-1.2, 20, None), + (-1.3, 30, None), + (-1.4, 99, None), + (-0.7, 40, None), + (-0.6, 50, None), + ], + "output_top_logprobs": [None] * 6, + }, + } + proto = TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=2) + assert proto is not None + assert list(proto.token_logprobs) == pytest.approx([-0.7, -0.6]) + assert list(proto.token_ids) == [40, 50] + + def test_top_k_alternatives(self): + # When the user requests top_logprobs=3, each position in + # output_top_logprobs is a list of K (logprob, token_id, text) tuples. + # Translate each into a TopLogProbs proto with parallel value/id arrays. 
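A sketch of just the two-step slice the class docstring above describes — the proto assembly is omitted and the helper name is hypothetical:

```python
def _slice_logprobs(meta_info, output_ids, n_keep):
    cumulative = meta_info.get("output_token_logprobs") or []
    if not cumulative or n_keep <= 0:
        return None  # logprobs disabled, or nothing left after the stop strip
    # The cumulative list covers the whole request; its last len(output_ids)
    # entries are the tokens this frame emitted...
    frame = cumulative[-len(output_ids):]
    # ...and keeping only the first n_keep realigns with output_ids after a
    # trailing stop token was stripped.
    return frame[:n_keep]

# Streaming-chunk case from the test above: 6 cumulative entries, 2 new
# tokens this chunk, nothing stripped -> the last two entries.
cumulative = [(-1.1, 10), (-1.2, 20), (-1.3, 30), (-1.4, 99), (-0.7, 40), (-0.6, 50)]
assert _slice_logprobs({"output_token_logprobs": cumulative}, [40, 50], 2) == [(-0.7, 40), (-0.6, 50)]
```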
+ out = { + "output_ids": [40], + "meta_info": { + "output_token_logprobs": [(-0.7, 40, None)], + "output_top_logprobs": [ + [(-0.7, 40, None), (-1.2, 41, None), (-2.5, 42, None)], + ], + }, + } + proto = TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=1) + assert proto is not None + assert len(proto.top_logprobs) == 1 + tl = proto.top_logprobs[0] + assert list(tl.values) == pytest.approx([-0.7, -1.2, -2.5]) + assert list(tl.token_ids) == [40, 41, 42] + + def test_strips_stop_token_alignment(self): + # When ``_generated_output_ids`` strips a trailing stop token, + # n_keep == len(output_ids) - 1. The converter must take the first + # n_keep entries of this frame's cumulative slice — emitting the + # logprob for the stripped stop token would misalign with the + # ``token_ids`` field on the proto. + out = { + "output_ids": [10, 20, 99], # 99 = stop, will be stripped → n_keep=2 + "meta_info": { + "output_token_logprobs": [ + (-0.5, 10, None), + (-0.3, 20, None), + (-0.1, 99, None), # logprob for the stop we just stripped + ], + "output_top_logprobs": [None, None, None], + }, + } + proto = TokenSpeedSchedulerServicer._convert_output_logprobs_to_proto(out, n_keep=2) + assert proto is not None + # Note: 99's logprob is dropped; emitted logprobs match the kept tokens. + assert list(proto.token_logprobs) == pytest.approx([-0.5, -0.3]) + assert list(proto.token_ids) == [10, 20] diff --git a/model_gateway/src/routers/grpc/client.rs b/model_gateway/src/routers/grpc/client.rs index 81cfc8f11..4604d75f2 100644 --- a/model_gateway/src/routers/grpc/client.rs +++ b/model_gateway/src/routers/grpc/client.rs @@ -8,7 +8,7 @@ use openai_protocol::{ }; use smg_grpc_client::{ tokenizer_bundle, tokenizer_bundle::StreamBundle, MlxEngineClient, SglangSchedulerClient, - TrtllmServiceClient, VllmEngineClient, + TokenSpeedSchedulerClient, TrtllmServiceClient, VllmEngineClient, }; use crate::routers::grpc::{ @@ -23,13 +23,15 @@ pub struct HealthCheckResponse { pub message: String, } -/// Polymorphic gRPC client that wraps SGLang, vLLM, TensorRT-LLM, or MLX +/// Wraps the per-backend gRPC clients. RPCs absent on a backend's wire +/// return `Status::unimplemented`. 
#[derive(Clone)] pub enum GrpcClient { Sglang(SglangSchedulerClient), Vllm(VllmEngineClient), Trtllm(TrtllmServiceClient), Mlx(MlxEngineClient), + TokenSpeed(TokenSpeedSchedulerClient), } impl GrpcClient { @@ -137,6 +139,32 @@ impl GrpcClient { matches!(self, Self::Mlx(_)) } + #[expect( + clippy::panic, + reason = "typed accessor: caller guarantees variant via is_tokenspeed() check" + )] + pub fn as_tokenspeed(&self) -> &TokenSpeedSchedulerClient { + match self { + Self::TokenSpeed(client) => client, + _ => panic!("Expected TokenSpeed client"), + } + } + + #[expect( + clippy::panic, + reason = "typed accessor: caller guarantees variant via is_tokenspeed() check" + )] + pub fn as_tokenspeed_mut(&mut self) -> &mut TokenSpeedSchedulerClient { + match self { + Self::TokenSpeed(client) => client, + _ => panic!("Expected TokenSpeed client"), + } + } + + pub fn is_tokenspeed(&self) -> bool { + matches!(self, Self::TokenSpeed(_)) + } + pub async fn connect( url: &str, runtime_type: &str, @@ -146,6 +174,9 @@ impl GrpcClient { "vllm" => Ok(Self::Vllm(VllmEngineClient::connect(url).await?)), "trtllm" | "tensorrt-llm" => Ok(Self::Trtllm(TrtllmServiceClient::connect(url).await?)), "mlx" => Ok(Self::Mlx(MlxEngineClient::connect(url).await?)), + "tokenspeed" => Ok(Self::TokenSpeed( + TokenSpeedSchedulerClient::connect(url).await?, + )), _ => Err(format!("Unknown runtime type: {runtime_type}").into()), } } @@ -182,6 +213,13 @@ impl GrpcClient { message: resp.message, }) } + Self::TokenSpeed(client) => { + let resp = client.health_check().await?; + Ok(HealthCheckResponse { + healthy: resp.healthy, + message: resp.message, + }) + } } } @@ -191,24 +229,32 @@ impl GrpcClient { Self::Vllm(client) => Ok(ModelInfo::Vllm(client.get_model_info().await?)), Self::Trtllm(client) => Ok(ModelInfo::Trtllm(client.get_model_info().await?)), Self::Mlx(client) => Ok(ModelInfo::Mlx(client.get_model_info().await?)), + Self::TokenSpeed(client) => { + Ok(ModelInfo::Sglang(Box::new(client.get_model_info().await?))) + } } } /// Get the full load response from the backend. - /// Only supported for SGLang backends. Returns per-DP-rank load metrics. + /// Returns `Unimplemented` for backends without scheduler load metrics. pub async fn get_loads(&self) -> Result { match self { Self::Sglang(client) => { let resp = client.get_loads(vec!["core".to_string()]).await?; Ok(WorkerLoadResponse::from(resp)) } + Self::TokenSpeed(client) => { + let resp = client.get_loads(vec!["core".to_string()]).await?; + Ok(WorkerLoadResponse::from(resp)) + } _ => Err(tonic::Status::unimplemented( "GetLoads RPC not supported for this backend", )), } } - /// Subscribe to KV cache events (all backends). + /// Subscribe to KV cache events. Returns `Unimplemented` on backends + /// without KV-event streaming. 
pub async fn subscribe_kv_events( &self, start_seq: u64, @@ -220,6 +266,9 @@ impl GrpcClient { Self::Mlx(_) => Err(tonic::Status::unimplemented( "SubscribeKvEvents RPC not supported for MLX backend", )), + Self::TokenSpeed(_) => Err(tonic::Status::unimplemented( + "SubscribeKvEvents RPC not supported for TokenSpeed backend", + )), } } @@ -231,6 +280,9 @@ impl GrpcClient { Self::Vllm(client) => Ok(ServerInfo::Vllm(client.get_server_info().await?)), Self::Trtllm(client) => Ok(ServerInfo::Trtllm(client.get_server_info().await?)), Self::Mlx(client) => Ok(ServerInfo::Mlx(client.get_server_info().await?)), + Self::TokenSpeed(client) => Ok(ServerInfo::Sglang(Box::new( + client.get_server_info().await?, + ))), } } @@ -243,6 +295,11 @@ impl GrpcClient { Self::Vllm(client) => client.get_tokenizer().await, Self::Trtllm(client) => client.get_tokenizer().await, Self::Mlx(client) => client.get_tokenizer().await, + Self::TokenSpeed(_) => { + return Err(Box::new(tonic::Status::unimplemented( + "TokenSpeed backend does not support GetTokenizer RPC", + ))); + } }?; tokenizer_bundle::validate_bundle_sha256(&bundle).map_err(|e| { @@ -280,6 +337,10 @@ impl GrpcClient { let stream = client.generate(*boxed_req).await?; Ok(ProtoStream::Mlx(stream)) } + (Self::TokenSpeed(client), ProtoGenerateRequest::TokenSpeed(boxed_req)) => { + let stream = client.generate(*boxed_req).await?; + Ok(ProtoStream::TokenSpeed(stream)) + } #[expect( clippy::panic, reason = "client and request types are always matched by construction in the pipeline" @@ -301,6 +362,9 @@ impl GrpcClient { let resp = client.embed(*boxed_req).await?; Ok(ProtoEmbedComplete::Vllm(resp)) } + (Self::TokenSpeed(_), _) => Err(tonic::Status::unimplemented( + "TokenSpeed backend does not support embedding", + )), (Self::Mlx(_), _) => Err(tonic::Status::unimplemented( "MLX backend does not support embedding", )), @@ -382,6 +446,19 @@ impl GrpcClient { )?; Ok(ProtoGenerateRequest::Mlx(Box::new(req))) } + Self::TokenSpeed(client) => { + if multimodal_inputs.is_some() { + return Err("TokenSpeed backend does not support multimodal inputs".to_string()); + } + let req = client.build_generate_request_from_chat( + request_id, + body, + processed_text, + token_ids, + tool_constraints, + )?; + Ok(ProtoGenerateRequest::TokenSpeed(Box::new(req))) + } } } @@ -455,6 +532,19 @@ impl GrpcClient { )?; Ok(ProtoGenerateRequest::Mlx(Box::new(req))) } + Self::TokenSpeed(client) => { + if multimodal_inputs.is_some() { + return Err("TokenSpeed backend does not support multimodal inputs".to_string()); + } + let req = client.build_generate_request_from_messages( + request_id, + body, + processed_text, + token_ids, + tool_constraints, + )?; + Ok(ProtoGenerateRequest::TokenSpeed(Box::new(req))) + } } } @@ -502,6 +592,15 @@ impl GrpcClient { )?; Ok(ProtoGenerateRequest::Mlx(Box::new(req))) } + Self::TokenSpeed(client) => { + let req = client.build_generate_request_from_completion( + request_id, + body, + original_text, + token_ids, + )?; + Ok(ProtoGenerateRequest::TokenSpeed(Box::new(req))) + } } } @@ -549,6 +648,15 @@ impl GrpcClient { )?; Ok(ProtoGenerateRequest::Mlx(Box::new(req))) } + Self::TokenSpeed(client) => { + let req = client.build_plain_generate_request( + request_id, + body, + original_text, + token_ids, + )?; + Ok(ProtoGenerateRequest::TokenSpeed(Box::new(req))) + } } } } diff --git a/model_gateway/src/routers/grpc/common/stages/helpers.rs b/model_gateway/src/routers/grpc/common/stages/helpers.rs index 65ef6dded..f780c4879 100644 --- 
a/model_gateway/src/routers/grpc/common/stages/helpers.rs +++ b/model_gateway/src/routers/grpc/common/stages/helpers.rs @@ -6,7 +6,7 @@ use rand::Rng; use smg_grpc_client::{ mlx_proto, sglang_proto::{self, DisaggregatedParams}, - vllm_proto, + tokenspeed_proto, vllm_proto, }; use tracing::{debug, warn}; @@ -156,6 +156,15 @@ pub(crate) fn apply_sampling_defaults_to_generate_request( }; apply_mlx_sampling_defaults(params, defaults, mask); } + ProtoGenerateRequest::TokenSpeed(req) => { + let Some(params) = req.sampling_params.as_mut() else { + warn!( + "Cannot apply sampling defaults to TokenSpeed request without sampling_params" + ); + return; + }; + apply_tokenspeed_sampling_defaults(params, defaults, mask); + } ProtoGenerateRequest::Trtllm(_) => {} } } @@ -218,6 +227,30 @@ optional_temperature_sampling_defaults_fn!( ); optional_temperature_sampling_defaults_fn!(apply_mlx_sampling_defaults, mlx_proto::SamplingParams); +/// TokenSpeed declares every sampling scalar as `optional` so the servicer +/// can distinguish "client set 0" from "client unset". Apply defaults by +/// writing `Some(value)` rather than the bare value. +fn apply_tokenspeed_sampling_defaults( + params: &mut tokenspeed_proto::SamplingParams, + defaults: SamplingDefaults, + mask: SamplingDefaultsMask, +) { + macro_rules! apply_opt { + ($field:ident) => { + if mask.$field { + if let Some(value) = defaults.$field { + params.$field = Some(value); + } + } + }; + } + apply_opt!(temperature); + apply_opt!(top_p); + apply_opt!(top_k); + apply_opt!(min_p); + apply_opt!(repetition_penalty); +} + /// Inject PD bootstrap metadata for SGLang if needed. /// /// SGLang uses DisaggregatedParams with bootstrap host/port/room. diff --git a/model_gateway/src/routers/grpc/common/stages/request_execution.rs b/model_gateway/src/routers/grpc/common/stages/request_execution.rs index 07cf75496..950b6e4dc 100644 --- a/model_gateway/src/routers/grpc/common/stages/request_execution.rs +++ b/model_gateway/src/routers/grpc/common/stages/request_execution.rs @@ -114,6 +114,7 @@ impl PipelineStage for RequestExecutionStage { } Some(RuntimeType::Trtllm) | Some(RuntimeType::Mlx) + | Some(RuntimeType::TokenSpeed) | Some(RuntimeType::External) | Some(RuntimeType::Unspecified) => { error!( diff --git a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs index d084d66f3..edce1e2ee 100644 --- a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs +++ b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs @@ -277,6 +277,50 @@ impl PipelineStage for HarmonyRequestBuildingStage { }; ProtoGenerateRequest::Mlx(Box::new(req)) } + GrpcClient::TokenSpeed(tokenspeed_client) => { + let req = match &ctx.input.request_type { + RequestType::Chat(request) => { + let body = modified_request.as_deref().unwrap_or_else(|| request.as_ref()); + tokenspeed_client + .build_generate_request_from_chat( + request_id, + body, + placeholder_processed_text, + token_ids, + tool_constraints, + ) + .map_err(|e| { + error!(function = "HarmonyRequestBuildingStage::execute", error = %e, "Failed to build TokenSpeed generate request"); + error::bad_request("invalid_request_parameters", format!("Invalid request parameters: {e}")) + })? 
+ } + RequestType::Responses(request) => tokenspeed_client + .build_generate_request_from_responses( + request_id, + request.as_ref(), + placeholder_processed_text, + token_ids, + tool_constraints, + ) + .map_err(|e| { + error!(function = "HarmonyRequestBuildingStage::execute", error = %e, "Failed to build TokenSpeed generate request from responses"); + error::bad_request("invalid_request_parameters", format!("Invalid request parameters: {e}")) + })?, + RequestType::Embedding(_) => { + return Err(error::bad_request( + "harmony_embedding_not_supported", + "Embedding requests are not supported with Harmony models".to_string(), + )); + } + _ => { + return Err(error::bad_request( + "unsupported_request_type", + "Unsupported request type for Harmony models".to_string(), + )); + } + }; + ProtoGenerateRequest::TokenSpeed(Box::new(req)) + } }; // Inject Harmony stop token IDs into sampling params for ALL Harmony requests @@ -322,6 +366,15 @@ impl PipelineStage for HarmonyRequestBuildingStage { ); } } + ProtoGenerateRequest::TokenSpeed(req) => { + if let Some(params) = req.sampling_params.as_mut() { + params.stop_token_ids.extend_from_slice(&harmony_stop_ids); + debug!( + stop_token_count = harmony_stop_ids.len(), + "Injected Harmony stop tokens into TokenSpeed sampling params" + ); + } + } } } diff --git a/model_gateway/src/routers/grpc/multimodal.rs b/model_gateway/src/routers/grpc/multimodal.rs index d28c7e3fc..a30d42135 100644 --- a/model_gateway/src/routers/grpc/multimodal.rs +++ b/model_gateway/src/routers/grpc/multimodal.rs @@ -708,6 +708,9 @@ pub(crate) fn assemble_multimodal_data( GrpcClient::Mlx(_) => unreachable!( "caller rejects multimodal for MLX in build_chat_request/build_messages_request" ), + GrpcClient::TokenSpeed(_) => unreachable!( + "TokenSpeed backend does not support multimodal; preparation stage should reject earlier" + ), } } diff --git a/model_gateway/src/routers/grpc/proto_wrapper.rs b/model_gateway/src/routers/grpc/proto_wrapper.rs index 971ff388b..48c7aef22 100644 --- a/model_gateway/src/routers/grpc/proto_wrapper.rs +++ b/model_gateway/src/routers/grpc/proto_wrapper.rs @@ -1,7 +1,8 @@ -//! Protocol buffer type wrappers for SGLang, vLLM, and TensorRT-LLM backends +//! Protocol buffer type wrappers for the supported gRPC backends. //! -//! This module provides unified enums that wrap proto types from SGLang, vLLM, and TensorRT-LLM, -//! allowing the router to work with any backend transparently. +//! This module provides unified enums that wrap proto types from each +//! supported backend, allowing the router to work with any backend +//! transparently. 
use std::collections::HashMap; @@ -11,6 +12,10 @@ use smg_grpc_client::{ mlx_proto::{self as mlx}, sglang_proto::{self as sglang, generate_complete::MatchedStop as SglangMatchedStop}, sglang_scheduler::AbortOnDropStream as SglangStream, + tokenspeed_proto::{ + self as tokenspeed, generate_complete::MatchedStop as TokenSpeedMatchedStop, + }, + tokenspeed_scheduler::AbortOnDropStream as TokenSpeedStream, trtllm_proto::{self as trtllm, generate_complete::MatchedStop as TrtllmMatchedStop}, trtllm_service::AbortOnDropStream as TrtllmStream, vllm_engine::AbortOnDropStream as VllmStream, @@ -280,6 +285,7 @@ pub enum ProtoGenerateRequest { Vllm(Box), Trtllm(Box), Mlx(Box), + TokenSpeed(Box), } impl ProtoGenerateRequest { @@ -355,6 +361,30 @@ impl ProtoGenerateRequest { } } + /// Get TokenSpeed variant (panics if not TokenSpeed) + #[expect( + clippy::panic, + reason = "typed accessor: caller guarantees variant via is_tokenspeed() check" + )] + pub fn as_tokenspeed(&self) -> &tokenspeed::GenerateRequest { + match self { + Self::TokenSpeed(req) => req, + _ => panic!("Expected TokenSpeed GenerateRequest"), + } + } + + /// Get mutable TokenSpeed variant (panics if not TokenSpeed) + #[expect( + clippy::panic, + reason = "typed accessor: caller guarantees variant via is_tokenspeed() check" + )] + pub fn as_tokenspeed_mut(&mut self) -> &mut tokenspeed::GenerateRequest { + match self { + Self::TokenSpeed(req) => req, + _ => panic!("Expected TokenSpeed GenerateRequest"), + } + } + /// Check if this is SGLang pub fn is_sglang(&self) -> bool { matches!(self, Self::Sglang(_)) @@ -370,6 +400,11 @@ impl ProtoGenerateRequest { matches!(self, Self::Trtllm(_)) } + /// Check if this is TokenSpeed + pub fn is_tokenspeed(&self) -> bool { + matches!(self, Self::TokenSpeed(_)) + } + /// Set max_tokens for prefill-only execution (vLLM PD mode). /// The prefill request uses max_tokens=1 to trigger KV cache computation /// without generating unnecessary tokens. 
@@ -385,7 +420,7 @@ impl ProtoGenerateRequest { }); } } - Self::Sglang(_) | Self::Trtllm(_) | Self::Mlx(_) => { + Self::Sglang(_) | Self::Trtllm(_) | Self::Mlx(_) | Self::TokenSpeed(_) => { tracing::warn!("set_max_tokens_for_prefill called on non-vLLM request, ignoring"); } } @@ -398,6 +433,7 @@ impl ProtoGenerateRequest { Self::Sglang(req) => req.stream = stream, Self::Trtllm(req) => req.streaming = stream, Self::Mlx(req) => req.stream = stream, + Self::TokenSpeed(req) => req.stream = stream, } } @@ -415,7 +451,8 @@ impl ProtoGenerateRequest { match self { Self::Sglang(req) => req.mm_inputs = None, Self::Vllm(req) => req.mm_inputs = None, - Self::Trtllm(_) | Self::Mlx(_) => {} // TRT-LLM and MLX protos have no mm_inputs field + // TRT-LLM, MLX, and TokenSpeed protos have no mm_inputs field + Self::Trtllm(_) | Self::Mlx(_) | Self::TokenSpeed(_) => {} } } @@ -426,6 +463,7 @@ impl ProtoGenerateRequest { Self::Vllm(req) => &req.request_id, Self::Trtllm(req) => &req.request_id, Self::Mlx(req) => &req.request_id, + Self::TokenSpeed(req) => &req.request_id, } } @@ -439,7 +477,7 @@ impl ProtoGenerateRequest { remote_port, }); } - Self::Sglang(_) | Self::Trtllm(_) | Self::Mlx(_) => { + Self::Sglang(_) | Self::Trtllm(_) | Self::Mlx(_) | Self::TokenSpeed(_) => { tracing::warn!("set_kv_transfer_params called on non-vLLM request, ignoring"); } } @@ -452,6 +490,7 @@ pub enum ProtoGenerateResponse { Vllm(Box), Trtllm(Box), Mlx(Box), + TokenSpeed(Box), } impl ProtoGenerateResponse { @@ -496,6 +535,15 @@ impl ProtoGenerateResponse { } None => ProtoResponseVariant::None, }, + Self::TokenSpeed(resp) => match resp.response { + Some(tokenspeed::generate_response::Response::Chunk(chunk)) => { + ProtoResponseVariant::Chunk(ProtoGenerateStreamChunk::TokenSpeed(chunk)) + } + Some(tokenspeed::generate_response::Response::Complete(complete)) => { + ProtoResponseVariant::Complete(ProtoGenerateComplete::TokenSpeed(complete)) + } + None => ProtoResponseVariant::None, + }, } } } @@ -514,6 +562,7 @@ pub enum ProtoGenerateStreamChunk { Vllm(vllm::GenerateStreamChunk), Trtllm(trtllm::GenerateStreamChunk), Mlx(mlx::GenerateStreamChunk), + TokenSpeed(tokenspeed::GenerateStreamChunk), } impl ProtoGenerateStreamChunk { @@ -573,6 +622,11 @@ impl ProtoGenerateStreamChunk { matches!(self, Self::Mlx(_)) } + /// Check if this is TokenSpeed + pub fn is_tokenspeed(&self) -> bool { + matches!(self, Self::TokenSpeed(_)) + } + /// Get token IDs from chunk (common field) pub fn token_ids(&self) -> &[u32] { match self { @@ -580,6 +634,7 @@ impl ProtoGenerateStreamChunk { Self::Vllm(c) => &c.token_ids, Self::Trtllm(c) => &c.token_ids, Self::Mlx(c) => &c.token_ids, + Self::TokenSpeed(c) => &c.token_ids, } } @@ -591,10 +646,11 @@ impl ProtoGenerateStreamChunk { Self::Vllm(c) => c.index, Self::Trtllm(c) => c.sequence_index, Self::Mlx(c) => c.index, + Self::TokenSpeed(c) => c.index, } } - /// Get output logprobs (SGLang, vLLM, TensorRT-LLM, and MLX) + /// Get output logprobs. 
pub fn output_logprobs(&self) -> Option { match self { Self::Sglang(c) => c @@ -610,6 +666,10 @@ impl ProtoGenerateStreamChunk { .output_logprobs .as_ref() .map(|lp| convert_output_logprobs!(lp)), + Self::TokenSpeed(c) => c + .output_logprobs + .as_ref() + .map(|lp| convert_output_logprobs!(lp)), } } @@ -624,8 +684,8 @@ impl ProtoGenerateStreamChunk { .input_logprobs .as_ref() .map(|lp| convert_input_logprobs!(lp)), - // TRT-LLM and MLX streaming chunks don't have input_logprobs - Self::Trtllm(_) | Self::Mlx(_) => None, + // TRT-LLM, MLX, and TokenSpeed streaming chunks don't have input_logprobs + Self::Trtllm(_) | Self::Mlx(_) | Self::TokenSpeed(_) => None, } } @@ -636,6 +696,7 @@ impl ProtoGenerateStreamChunk { Self::Vllm(c) => c.prompt_tokens, Self::Trtllm(c) => c.prompt_tokens, Self::Mlx(c) => c.prompt_tokens, + Self::TokenSpeed(c) => c.prompt_tokens, } } @@ -646,6 +707,7 @@ impl ProtoGenerateStreamChunk { Self::Vllm(c) => c.completion_tokens, Self::Trtllm(c) => c.completion_tokens, Self::Mlx(c) => c.completion_tokens, + Self::TokenSpeed(c) => c.completion_tokens, } } @@ -656,6 +718,7 @@ impl ProtoGenerateStreamChunk { Self::Vllm(c) => c.cached_tokens, Self::Trtllm(c) => c.cached_tokens, Self::Mlx(c) => c.cached_tokens, + Self::TokenSpeed(c) => c.cached_tokens, } } } @@ -667,6 +730,7 @@ pub enum ProtoGenerateComplete { Vllm(vllm::GenerateComplete), Trtllm(trtllm::GenerateComplete), Mlx(mlx::GenerateComplete), + TokenSpeed(tokenspeed::GenerateComplete), } impl ProtoGenerateComplete { @@ -738,6 +802,11 @@ impl ProtoGenerateComplete { matches!(self, Self::Mlx(_)) } + /// Check if this is TokenSpeed + pub fn is_tokenspeed(&self) -> bool { + matches!(self, Self::TokenSpeed(_)) + } + /// Get token IDs from either backend (output_ids in proto) pub fn token_ids(&self) -> &[u32] { match self { @@ -745,6 +814,7 @@ impl ProtoGenerateComplete { Self::Vllm(c) => &c.output_ids, Self::Trtllm(c) => &c.output_token_ids, Self::Mlx(c) => &c.output_ids, + Self::TokenSpeed(c) => &c.output_ids, } } @@ -755,6 +825,7 @@ impl ProtoGenerateComplete { Self::Vllm(c) => c.prompt_tokens, Self::Trtllm(c) => c.prompt_tokens, Self::Mlx(c) => c.prompt_tokens, + Self::TokenSpeed(c) => c.prompt_tokens, } } @@ -765,6 +836,7 @@ impl ProtoGenerateComplete { Self::Vllm(c) => c.completion_tokens, Self::Trtllm(c) => c.completion_tokens, Self::Mlx(c) => c.completion_tokens, + Self::TokenSpeed(c) => c.completion_tokens, } } @@ -775,6 +847,7 @@ impl ProtoGenerateComplete { Self::Vllm(c) => &c.finish_reason, Self::Trtllm(c) => &c.finish_reason, Self::Mlx(c) => &c.finish_reason, + Self::TokenSpeed(c) => &c.finish_reason, } } @@ -786,6 +859,7 @@ impl ProtoGenerateComplete { Self::Vllm(c) => c.index, Self::Trtllm(c) => c.sequence_index, Self::Mlx(c) => c.index, + Self::TokenSpeed(c) => c.index, } } @@ -823,6 +897,11 @@ impl ProtoGenerateComplete { Self::Mlx(c) => c .matched_stop_token_id .map(|id| serde_json::Value::Number(id.into())), + Self::TokenSpeed(c) => convert!( + &c.matched_stop, + TokenSpeedMatchedStop::MatchedTokenId, + TokenSpeedMatchedStop::MatchedStopStr + ), } } @@ -833,6 +912,7 @@ impl ProtoGenerateComplete { Self::Vllm(c) => &c.output_ids, Self::Trtllm(c) => &c.output_token_ids, Self::Mlx(c) => &c.output_ids, + Self::TokenSpeed(c) => &c.output_ids, } } @@ -843,6 +923,7 @@ impl ProtoGenerateComplete { Self::Vllm(c) => c.cached_tokens, Self::Trtllm(c) => c.cached_tokens, Self::Mlx(c) => c.cached_tokens, + Self::TokenSpeed(c) => c.cached_tokens, } } @@ -881,12 +962,12 @@ impl ProtoGenerateComplete { }) } } - // MLX does 
not have input_logprobs - Self::Mlx(_) => None, + // MLX and TokenSpeed do not have input_logprobs + Self::Mlx(_) | Self::TokenSpeed(_) => None, } } - /// Get output logprobs (SGLang, vLLM, TensorRT-LLM, and MLX) + /// Get output logprobs. pub fn output_logprobs(&self) -> Option { match self { Self::Sglang(c) => c @@ -902,6 +983,10 @@ impl ProtoGenerateComplete { .output_logprobs .as_ref() .map(|lp| convert_output_logprobs!(lp)), + Self::TokenSpeed(c) => c + .output_logprobs + .as_ref() + .map(|lp| convert_output_logprobs!(lp)), } } @@ -913,17 +998,22 @@ impl ProtoGenerateComplete { .kv_transfer_params .as_ref() .map(|params| (params.remote_host.clone(), params.remote_port)), - Self::Sglang(_) | Self::Trtllm(_) | Self::Mlx(_) => None, + Self::Sglang(_) | Self::Trtllm(_) | Self::Mlx(_) | Self::TokenSpeed(_) => None, } } } -/// Unified stream wrapper +/// Unified stream wrapper. +/// +/// One variant per backend. Each yields its own native proto response shape; +/// the chunk / complete accessors above match on the corresponding +/// `ProtoGenerateStreamChunk` / `ProtoGenerateComplete` arm. pub enum ProtoStream { Sglang(SglangStream), Vllm(VllmStream), Trtllm(TrtllmStream), Mlx(MlxStream), + TokenSpeed(TokenSpeedStream), } impl ProtoStream { @@ -946,6 +1036,10 @@ impl ProtoStream { .next() .await .map(|result| result.map(|r| ProtoGenerateResponse::Mlx(Box::new(r)))), + Self::TokenSpeed(stream) => stream + .next() + .await + .map(|result| result.map(|r| ProtoGenerateResponse::TokenSpeed(Box::new(r)))), } } @@ -956,6 +1050,7 @@ impl ProtoStream { Self::Vllm(stream) => stream.mark_completed(), Self::Trtllm(stream) => stream.mark_completed(), Self::Mlx(stream) => stream.mark_completed(), + Self::TokenSpeed(stream) => stream.mark_completed(), } } } diff --git a/model_gateway/src/routers/grpc/regular/stages/embedding/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/embedding/request_building.rs index b9ee0a845..c69eef5b9 100644 --- a/model_gateway/src/routers/grpc/regular/stages/embedding/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/embedding/request_building.rs @@ -96,6 +96,16 @@ impl PipelineStage for EmbeddingRequestBuildingStage { "MLX embedding is not supported via gRPC", )); } + GrpcClient::TokenSpeed(_) => { + error!( + function = "EmbeddingRequestBuildingStage::execute", + "TokenSpeed backend does not support embeddings" + ); + return Err(error::not_implemented( + "unsupported_backend", + "TokenSpeed backend does not support embeddings", + )); + } }; ctx.state.proto_request = Some(ProtoRequest::Embed(proto_req)); diff --git a/model_gateway/src/workflow/steps/local/detect_backend.rs b/model_gateway/src/workflow/steps/local/detect_backend.rs index 3295d6f82..235b672af 100644 --- a/model_gateway/src/workflow/steps/local/detect_backend.rs +++ b/model_gateway/src/workflow/steps/local/detect_backend.rs @@ -1,8 +1,8 @@ //! Backend runtime detection step. //! -//! Detects the runtime type (sglang, vllm, trtllm, mlx) for both HTTP and gRPC workers. +//! Detects the runtime type (sglang, vllm, trtllm, tokenspeed, mlx) for both HTTP and gRPC workers. //! - HTTP: probes `/v1/models` (owned_by field), falls back to unique endpoints. -//! - gRPC: tries sglang → vllm → trtllm → mlx health checks sequentially. +//! - gRPC: tries sglang → vllm → trtllm → tokenspeed → mlx health checks sequentially. 
use std::time::Duration; @@ -44,7 +44,7 @@ async fn detect_grpc_backend( } // Try each runtime sequentially (most common first), skipping the hint we already tried - for runtime in &["sglang", "vllm", "trtllm", "mlx"] { + for runtime in &["sglang", "vllm", "trtllm", "tokenspeed", "mlx"] { if Some(*runtime) == runtime_hint { continue; } @@ -57,7 +57,7 @@ async fn detect_grpc_backend( } Err(format!( - "gRPC backend detection failed for {url} (tried sglang, vllm, trtllm, mlx)" + "gRPC backend detection failed for {url} (tried sglang, vllm, trtllm, tokenspeed, mlx)" )) } diff --git a/model_gateway/src/workflow/steps/util.rs b/model_gateway/src/workflow/steps/util.rs index efd6753e3..ff7d17208 100644 --- a/model_gateway/src/workflow/steps/util.rs +++ b/model_gateway/src/workflow/steps/util.rs @@ -88,17 +88,22 @@ pub(crate) async fn try_grpc_reachable(url: &str, timeout_secs: u64) -> Result<( format!("grpc://{}", strip_protocol(url)) }; - let (sglang, vllm, trtllm, mlx) = tokio::join!( + let (sglang, vllm, trtllm, mlx, tokenspeed) = tokio::join!( do_grpc_health_check(&grpc_url, timeout_secs, "sglang"), do_grpc_health_check(&grpc_url, timeout_secs, "vllm"), do_grpc_health_check(&grpc_url, timeout_secs, "trtllm"), do_grpc_health_check(&grpc_url, timeout_secs, "mlx"), + do_grpc_health_check(&grpc_url, timeout_secs, "tokenspeed"), ); - match (sglang, vllm, trtllm, mlx) { - (Ok(()), _, _, _) | (_, Ok(()), _, _) | (_, _, Ok(()), _) | (_, _, _, Ok(())) => Ok(()), - (Err(e1), Err(e2), Err(e3), Err(e4)) => Err(format!( - "gRPC not reachable (tried sglang, vllm, trtllm, mlx): sglang={e1}, vllm={e2}, trtllm={e3}, mlx={e4}", + match (sglang, vllm, trtllm, mlx, tokenspeed) { + (Ok(()), _, _, _, _) + | (_, Ok(()), _, _, _) + | (_, _, Ok(()), _, _) + | (_, _, _, Ok(()), _) + | (_, _, _, _, Ok(())) => Ok(()), + (Err(e1), Err(e2), Err(e3), Err(e4), Err(e5)) => Err(format!( + "gRPC not reachable (tried sglang, vllm, trtllm, mlx, tokenspeed): sglang={e1}, vllm={e2}, trtllm={e3}, mlx={e4}, tokenspeed={e5}", )), } } diff --git a/scripts/ci_install_tokenspeed.sh b/scripts/ci_install_tokenspeed.sh new file mode 100755 index 000000000..094560b62 --- /dev/null +++ b/scripts/ci_install_tokenspeed.sh @@ -0,0 +1,168 @@ +#!/bin/bash +# Install TokenSpeed from source (engine + kernel + scheduler) for CI. +# +# TokenSpeed is not published to PyPI, so we clone it and pip-install the +# in-tree ``tokenspeed-kernel`` (CUDA), ``tokenspeed-scheduler`` (C++/nanobind), +# and ``python/`` packages. Mirrors the upstream ``docker/Dockerfile`` pipeline. +# +# Prerequisites (expected on k8s-runner-gpu nodes): +# - NVIDIA driver 580+ (CUDA 13) +# - CUDA 13.0 toolkit at /usr/local/cuda-13.0 or /usr/local/cuda +# - H100 GPUs (sm90) +# +# Heavy first run (~30 min for kernel CUDA compile); subsequent runs on the +# same runner hit the pip wheel cache at /tmp/tokenspeed-wheel-cache/ and +# short-circuit the kernel build. + +set -euo pipefail + +# Activate venv if it exists +if [ -f ".venv/bin/activate" ]; then + source .venv/bin/activate +fi + +# Pinned SHA from lightseekorg/tokenspeed main. Bump explicitly (ideally via +# a scheduled bump-and-CI routine) rather than floating against ``main`` — +# upstream has renamed APIs before and the gRPC servicer broke until we +# caught up. 
+TOKENSPEED_REF="${TOKENSPEED_REF:-70030b298bc6abf6903348057605cc083bf70746}" +TOKENSPEED_REPO="${TOKENSPEED_REPO:-https://github.com/lightseekorg/tokenspeed.git}" +TOKENSPEED_DIR="${TOKENSPEED_DIR:-/tmp/tokenspeed-src}" +WHEEL_CACHE="${TOKENSPEED_WHEEL_CACHE:-/tmp/tokenspeed-wheel-cache}" + +# Install uv for faster package management (mirrors ci_install_sglang.sh). +if ! command -v uv &> /dev/null; then + echo "Installing uv..." + curl -LsSf https://astral.sh/uv/install.sh | sh + export PATH="$HOME/.local/bin:$PATH" +fi +echo "uv version: $(uv --version)" + +# ── CUDA runtime setup ───────────────────────────────────────────────────── +# k8s-runner-gpu ships the NVIDIA driver + CUDA runtime libs but not the +# SDK (nvcc, headers). Install them on demand — same approach as +# ``ci_install_sglang.sh``, which installs cuda-nvcc-12-9 + +# cuda-cudart-dev-12-9 when ``/usr/local/cuda/bin/nvcc`` is missing. +# TokenSpeed's Dockerfile targets CUDA 13.0, so install the matching +# toolkit packages here. +CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" +if [ ! -x "${CUDA_HOME}/bin/nvcc" ]; then + echo "Installing CUDA toolkit (nvcc not found at ${CUDA_HOME}/bin/nvcc)..." + curl -fsSL -o /tmp/cuda-keyring.deb \ + https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb + sudo dpkg -i /tmp/cuda-keyring.deb + rm /tmp/cuda-keyring.deb + sudo apt-get update -qq + # cuda-nvcc-13-0: provides nvcc + cuda_runtime_api.h + # cuda-cudart-dev-13-0: provides cuda_runtime.h + libcudart headers + # cuda-libraries-dev-13-0: meta-package pulling in cublas / curand / + # cusolver / cusparse / cufft / nvrtc / + # nvjitlink dev headers that tokenspeed-kernel + # needs (cublas_v2.h, curand.h, cublasLt.h, ...) + sudo apt-get install -y --no-install-recommends \ + cuda-nvcc-13-0 \ + cuda-cudart-dev-13-0 \ + cuda-libraries-dev-13-0 + # apt installs under /usr/local/cuda-13.0; expose the /usr/local/cuda + # alias the job-level ``CUDA_HOME: /usr/local/cuda`` env expects. + if [ ! -d "${CUDA_HOME}/bin" ] && [ -d "/usr/local/cuda-13.0/bin" ]; then + sudo ln -sfn /usr/local/cuda-13.0 "${CUDA_HOME}" + fi + echo "nvcc installed: $(${CUDA_HOME}/bin/nvcc --version | tail -1)" +else + echo "nvcc already available: $(${CUDA_HOME}/bin/nvcc --version | tail -1)" +fi +export CUDA_HOME +export PATH="$CUDA_HOME/bin:$PATH" +export LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${CUDA_HOME}/extras/CUPTI/lib64:${LD_LIBRARY_PATH:-}" +# Torch's JIT cpp_extension builder compiles some TokenSpeed runtime +# extensions (e.g. ``tokenspeed_hostfunc_ext``) with plain g++ and +# doesn't pass ``-I$CUDA_HOME/include``; expose the headers via CPATH / +# CPLUS_INCLUDE_PATH so the compile picks them up. +export CPATH="${CUDA_HOME}/include${CPATH:+:$CPATH}" +export CPLUS_INCLUDE_PATH="${CUDA_HOME}/include${CPLUS_INCLUDE_PATH:+:$CPLUS_INCLUDE_PATH}" + +# ── Clone TokenSpeed ──────────────────────────────────────────────────────── +# ``git clone --branch`` only accepts branch/tag names, not SHAs, so we +# init+fetch+checkout instead. Works for both SHAs and refs. +if [ ! -d "$TOKENSPEED_DIR" ]; then + echo "Cloning TokenSpeed ${TOKENSPEED_REF} from ${TOKENSPEED_REPO}..." 
+    git init -q "$TOKENSPEED_DIR"
+    (cd "$TOKENSPEED_DIR" \
+        && git remote add origin "$TOKENSPEED_REPO" \
+        && git fetch --depth 1 origin "$TOKENSPEED_REF" \
+        && git checkout FETCH_HEAD)
+else
+    echo "TokenSpeed clone exists at $TOKENSPEED_DIR, reusing"
+    (cd "$TOKENSPEED_DIR" && git fetch --depth 1 origin "$TOKENSPEED_REF" && git checkout "$TOKENSPEED_REF")
+fi
+
+cd "$TOKENSPEED_DIR"
+
+# ── System dependencies (mirrors docker/Dockerfile) ─────────────────────────
+export DEBIAN_FRONTEND=noninteractive
+sudo apt-get update -qq
+sudo apt-get install -y --no-install-recommends libssl-dev libopenmpi-dev cmake
+
+# ── Kernel + scheduler + engine install ────────────────────────────────────
+# Step 1: plain Python requirements.
+uv pip install -r tokenspeed-kernel/python/requirements/cuda.txt
+
+# Step 2: build-isolation=off so nanobind/cutlass build dependencies are shared.
+uv pip install -r tokenspeed-kernel/python/requirements/cuda-thirdparty.txt \
+    --no-build-isolation
+
+# Step 3: kernel (CUDA compile — the expensive one). Try the cached wheel first.
+CACHED_KERNEL_WHEEL=$(find "$WHEEL_CACHE" -name "tokenspeed_kernel-*.whl" 2>/dev/null | head -1 || true)
+if [ -n "$CACHED_KERNEL_WHEEL" ] && [ -f "$CACHED_KERNEL_WHEEL" ]; then
+    echo "Installing cached tokenspeed-kernel wheel: $CACHED_KERNEL_WHEEL"
+    uv pip install "$CACHED_KERNEL_WHEEL" --no-build-isolation
+else
+    echo "Building tokenspeed-kernel from source (this takes ~30 min the first time)..."
+    MAX_JOBS="${MAX_JOBS:-16}" FLASHINFER_CUDA_ARCH_LIST="9.0a 10.0a" \
+        uv pip install tokenspeed-kernel/python/ --no-build-isolation
+    # Cache the built wheel (best effort): uv keeps wheels it builds from
+    # source under its cache dir, so copy any tokenspeed_kernel wheel out to
+    # $WHEEL_CACHE so the next run on this runner skips the CUDA compile.
+    mkdir -p "$WHEEL_CACHE"
+    find "${UV_CACHE_DIR:-$HOME/.cache/uv}" -name "tokenspeed_kernel-*.whl" \
+        -exec cp {} "$WHEEL_CACHE/" \; 2>/dev/null || true
+fi
+
+# Step 4: scheduler (scikit-build-core + nanobind + CMake).
+echo "Building tokenspeed-scheduler..."
+uv pip install tokenspeed-scheduler/
+
+# Step 5: the Python runtime (pure-Python).
+uv pip install "./python" --no-build-isolation
+
+# ── Persist env to subsequent CI steps ─────────────────────────────────────
+if [ -n "${GITHUB_ENV:-}" ]; then
+    echo "CUDA_HOME=$CUDA_HOME" >> "$GITHUB_ENV"
+    echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" >> "$GITHUB_ENV"
+    # See note above: needed so torch's JIT C++ extension builder sees
+    # CUDA headers when it bypasses nvcc for .cpp sources.
+    echo "CPATH=$CPATH" >> "$GITHUB_ENV"
+    echo "CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH" >> "$GITHUB_ENV"
+fi
+if [ -n "${GITHUB_PATH:-}" ]; then
+    # Make ``nvcc`` discoverable to downstream steps (pytest spawns the
+    # worker which may trigger CUDA extension builds).
+    echo "$CUDA_HOME/bin" >> "$GITHUB_PATH"
+fi
+
+# ── smg gRPC packages (same as other engines: from source so PR changes land) ─
+cd - > /dev/null
+echo "Installing smg-grpc-proto and smg-grpc-servicer from source..."
+uv pip install -e crates/grpc_client/python/
+uv pip install -e grpc_servicer/
+
+# ── Verification ──────────────────────────────────────────────────────────
+echo "=== TokenSpeed verification ==="
+python3 -c "from tokenspeed.runtime.engine.async_llm import AsyncLLM; \
+    print('AsyncLLM bases:', [b.__name__ for b in AsyncLLM.__bases__])"
+python3 -c "from smg_grpc_servicer.tokenspeed.servicer import TokenSpeedSchedulerServicer; \
+    print('gRPC servicer: importable')"
+
+echo "TokenSpeed installation complete"
diff --git a/scripts/update_whl_index.py b/scripts/update_whl_index.py
new file mode 100755
index 000000000..b01c92a6b
--- /dev/null
+++ b/scripts/update_whl_index.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python3
+# Reference: https://github.com/flashinfer-ai/flashinfer/blob/v0.2.0/scripts/update_whl_index.py
+"""Update the wheel index in the lightseekorg/whl repository.
+
+Index layout (matches the existing repo structure):
+    /cu<ver>/index.html             ← PEP 503 top-level index
+    /cu<ver>/<package>/index.html   ← per-package wheel list
+    /rocm<ver>/index.html           ← PEP 503 top-level index
+    /rocm<ver>/<package>/index.html ← per-package wheel list
+
+Install example (PyTorch-style --extra-index-url):
+    pip install smg \
+        --extra-index-url https://lightseek.org/whl/cu129/
+"""
+
+import argparse
+import hashlib
+import pathlib
+
+BASE_URL = "https://github.com/lightseekorg/whl/releases/download"
+
+
+def _cuda_display(cuda_digits: str) -> str:
+    """'129' -> '12.9', '130' -> '13.0'"""
+    return f"{cuda_digits[:-1]}.{cuda_digits[-1]}"
+
+
+def _platform_index(cuda: str | None, rocm: str | None) -> tuple[str, str]:
+    if cuda:
+        return f"cu{cuda}", f"CUDA {_cuda_display(cuda)}"
+    if rocm:
+        return f"rocm{rocm}", f"ROCm {rocm}"
+    raise ValueError("Either cuda or rocm must be provided")
+
+
+def compute_sha256(path: pathlib.Path) -> str:
+    with open(path, "rb") as f:
+        return hashlib.sha256(f.read()).hexdigest()
+
+
+def _ensure_in_top_index(index_root: pathlib.Path, package: str) -> None:
+    """Add package to a platform index if not already listed (PEP 503)."""
+    top_index = index_root / "index.html"
+    entry = f'<a href="{package}/">{package}</a><br>\n'
+    if top_index.exists():
+        if entry in top_index.read_text():
+            return
+        with top_index.open("a") as f:
+            f.write(entry)
+    else:
+        top_index.write_text(f"<!DOCTYPE html>\n{entry}")
+    print(f" Added {package} to top-level index")
+
+
+def update_index(
+    package: str,
+    cuda: str | None,
+    rocm: str | None,
+    release_tag: str,
+    wheel_dir: str,
+    whl_repo_dir: str,
+) -> None:
+    platform_dir, platform_display = _platform_index(cuda, rocm)
+    index_root = pathlib.Path(whl_repo_dir) / platform_dir
+    index_dir = index_root / package
+    index_dir.mkdir(exist_ok=True, parents=True)
+
+    # Keep the platform index up-to-date for --index-url support
+    _ensure_in_top_index(index_root, package)
+
+    index_file = index_dir / "index.html"
+    if not index_file.exists():
+        index_file.write_text(
+            f"<!DOCTYPE html>\n<h1>{package} wheels for {platform_display}</h1>\n"
+        )
+
+    wheels = sorted(pathlib.Path(wheel_dir).glob("*.whl"))
+    if not wheels:
+        print(f"WARNING: no .whl files found in {wheel_dir}")
+        return
+
+    for path in wheels:
+        sha256 = compute_sha256(path)
+        full_url = f"{BASE_URL}/{release_tag}/{path.name}#sha256={sha256}"
+        with index_file.open("a") as f:
+            f.write(f'<a href="{full_url}">{path.name}</a><br>\n')
+        print(f" Indexed: {path.name}")
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Update wheel index for lightseekorg/whl")
+    parser.add_argument(
+        "--package",
+        required=True,
+        help="Package name",
+    )
+    platform = parser.add_mutually_exclusive_group(required=True)
+    platform.add_argument(
+        "--cuda",
+        help="CUDA version digits (e.g. 129, 130)",
+    )
+    platform.add_argument(
+        "--rocm",
+        help="ROCm version (e.g. 7.2)",
+    )
+    parser.add_argument(
+        "--release-tag",
+        required=True,
+        help="Release tag in lightseekorg/whl",
+    )
+    parser.add_argument(
+        "--wheel-dir",
+        default="wheelhouse",
+        help="Directory containing .whl files (default: wheelhouse)",
+    )
+    parser.add_argument(
+        "--whl-repo-dir",
+        default=".",
+        help="Root of the lightseekorg/whl checkout (default: .)",
+    )
+    args = parser.parse_args()
+    update_index(
+        args.package,
+        args.cuda,
+        args.rocm,
+        args.release_tag,
+        args.wheel_dir,
+        args.whl_repo_dir,
+    )
+
+
+if __name__ == "__main__":
+    main()
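Reviewer note: a quick way to sanity-check the index script is a throwaway run against a temp directory. The sketch below is illustrative only and not part of the patch; it assumes it is executed from the repo root so that scripts/update_whl_index.py is on disk, and the wheel filename and release tag are fabricated for the example.

# Hypothetical local smoke test for scripts/update_whl_index.py (not part of the patch).
import pathlib
import subprocess
import sys
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    tmp = pathlib.Path(tmp)
    wheel_dir = tmp / "wheelhouse"   # stands in for the build artifacts directory
    whl_repo = tmp / "whl"           # stands in for a lightseekorg/whl checkout
    wheel_dir.mkdir()
    whl_repo.mkdir()
    # File content is irrelevant for indexing; only the name and sha256 are used.
    (wheel_dir / "smg-0.0.0.dev1-py3-none-any.whl").write_bytes(b"stub")

    subprocess.run(
        [
            sys.executable, "scripts/update_whl_index.py",
            "--package", "smg",
            "--cuda", "129",
            "--release-tag", "smg-dev-1",
            "--wheel-dir", str(wheel_dir),
            "--whl-repo-dir", str(whl_repo),
        ],
        check=True,
    )

    # Expected result: whl/cu129/index.html lists the package, and
    # whl/cu129/smg/index.html gains one <a href="...#sha256=..."> line per wheel.
    print((whl_repo / "cu129" / "index.html").read_text())
    print((whl_repo / "cu129" / "smg" / "index.html").read_text())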