129 changes: 129 additions & 0 deletions .github/workflows/pr-test-mlx.yml
@@ -0,0 +1,129 @@
name: PR Test (MLX)

on:
push:
branches: [main]
paths:
- "crates/grpc_client/proto/mlx_engine.proto"
- "crates/grpc_client/src/mlx_engine.rs"
- "crates/grpc_client/python/**"
- "grpc_servicer/smg_grpc_servicer/mlx/**"
- "grpc_servicer/pyproject.toml"
- "e2e_test/mlx/test_mlx_backend.py"
- "e2e_test/infra/__init__.py"
- "e2e_test/infra/model_specs.py"
- "e2e_test/infra/worker.py"
- "e2e_test/infra/constants.py"
Comment on lines +14 to +16

P2: Add infra init to MLX workflow path filters

The new workflow watches several MLX-related E2E infra files but omits `e2e_test/infra/__init__.py`, even though this commit changes that module and `e2e_test/conftest.py` imports from `infra` at startup. A future PR that only updates `infra/__init__.py` can break MLX test startup/import resolution without triggering this workflow, so regressions can merge untested.

Collaborator Author: added

- ".github/workflows/pr-test-mlx.yml"
coderabbitai[bot] marked this conversation as resolved.
pull_request:
branches: [main]
types: [opened, synchronize, reopened]
paths:
- "crates/grpc_client/proto/mlx_engine.proto"
- "crates/grpc_client/src/mlx_engine.rs"
- "crates/grpc_client/python/**"
- "grpc_servicer/smg_grpc_servicer/mlx/**"
Comment on lines +22 to +25

P2: Include router-side MLX paths in workflow trigger

When a PR only changes the MLX branches in the router, e.g. `model_gateway/src/routers/grpc/client.rs` or `proto_wrapper.rs`, this workflow will not run, because the `pull_request.paths` list here covers only the proto/client package, the servicer, and the E2E infra. The regular PR GPU workflow's reusable E2E matrix covers only sglang, vllm, and trtllm, so those MLX router paths are not exercised there either; an MLX-specific routing regression can therefore merge without running the new Apple Silicon E2E job.

- "grpc_servicer/pyproject.toml"
- "e2e_test/mlx/test_mlx_backend.py"
- "e2e_test/infra/__init__.py"
- "e2e_test/infra/model_specs.py"
- "e2e_test/infra/worker.py"
- "e2e_test/infra/constants.py"
Comment on lines +27 to +31

P2: Include MLX fixture files in path filters

The MLX test module runs through `e2e_test/conftest.py` and the `setup_backend`/`api_client` fixtures in `e2e_test/fixtures/setup_backend.py`, but this new workflow watches only the test file plus a few infra modules. A PR that changes those shared fixtures can break MLX startup/client wiring without triggering this macOS workflow. The existing GPU workflow filters already treat `e2e_test/conftest.py` and `e2e_test/fixtures/**` as common E2E inputs; this new workflow does not.

- ".github/workflows/pr-test-mlx.yml"
workflow_dispatch:

permissions:
contents: read

concurrency:
group: mlx-tests-${{ github.ref }}
cancel-in-progress: ${{ github.event_name == 'pull_request' }}

jobs:
e2e-mlx:
name: E2E (MLX on Apple Silicon)
runs-on: macos-latest
timeout-minutes: 30
permissions:
contents: read
env:
E2E_RUNTIME: mlx
E2E_ENGINE: mlx
PYTHONUNBUFFERED: "1"
steps:
- name: Checkout code
uses: actions/checkout@v6

- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.12"

- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable

- name: Install protoc
run: brew install protobuf

- name: Cache Cargo registry and target
uses: Swatinem/rust-cache@v2
with:
shared-key: mlx-pr-test
cache-on-failure: true

- name: Install uv (for openapi codegen)
run: pip install uv

- name: Build smg-grpc-proto Python package (proto codegen)
run: pip install -e ./crates/grpc_client/python

- name: Install grpc_servicer with MLX extra (mlx + mlx-lm)
run: pip install -e "./grpc_servicer[mlx]"

- name: Build and install SMG Python bindings (ci profile)
working-directory: bindings/python
run: |
pip install maturin
# `ci` profile (opt-level=2, thin LTO, 16 codegen-units) — faster
# to compile than release, runtime still plenty fast for a
# correctness E2E test.
# Use `maturin build` + `pip install` (not `maturin develop`)
# because the GitHub-hosted runner's Python is not in a virtualenv.
maturin build --profile ci --out dist
pip install dist/*.whl

- name: Generate Python client types (required by e2e_test/conftest.py)
run: make generate-python-types

- name: Install smg-client
run: pip install ./clients/python

- name: Install E2E test dependencies
run: pip install ./e2e_test

- name: Verify imports
run: |
python -c "from smg_grpc_proto import mlx_engine_pb2, mlx_engine_pb2_grpc; print('proto OK')"
python -c "from smg_grpc_servicer.mlx.servicer import MlxEngineServicer; print('servicer OK')"
python -c "import smg; print('smg OK')"
python -c "from smg_client import SmgClient; print('smg_client OK')"
python -c "import mlx_lm; print('mlx-lm OK')"

- name: Run MLX E2E tests
env:
SHOW_WORKER_LOGS: "1"
SHOW_ROUTER_LOGS: "1"
E2E_LOG_DIR: e2e-logs
run: |
pytest e2e_test/mlx/test_mlx_backend.py \
-s -vv \
--reruns 1 --reruns-delay 5

- name: Upload logs on failure
if: failure() || cancelled()
uses: actions/upload-artifact@v7
with:
name: e2e-mlx-logs
path: e2e-logs/
retention-days: 7
if-no-files-found: ignore
2 changes: 2 additions & 0 deletions e2e_test/infra/__init__.py
@@ -33,6 +33,7 @@
Runtime,
WorkerType,
get_runtime,
is_mlx,
is_sglang,
is_trtllm,
is_vllm,
@@ -111,6 +112,7 @@
"is_vllm",
"is_sglang",
"is_trtllm",
"is_mlx",
# Port utilities
"get_open_port",
"release_port",
13 changes: 12 additions & 1 deletion e2e_test/infra/constants.py
@@ -25,6 +25,7 @@ class Runtime(StrEnum):
SGLANG = "sglang"
VLLM = "vllm"
TRTLLM = "trtllm"
MLX = "mlx"
OPENAI = "openai"
XAI = "xai"
GEMINI = "gemini"
@@ -33,7 +34,7 @@ class Runtime(StrEnum):

# Convenience sets
LOCAL_MODES = frozenset({ConnectionMode.HTTP, ConnectionMode.GRPC})
LOCAL_RUNTIMES = frozenset({Runtime.SGLANG, Runtime.VLLM, Runtime.TRTLLM})
LOCAL_RUNTIMES = frozenset({Runtime.SGLANG, Runtime.VLLM, Runtime.TRTLLM, Runtime.MLX})
CLOUD_RUNTIMES = frozenset({Runtime.OPENAI, Runtime.XAI, Runtime.GEMINI, Runtime.ANTHROPIC})

# Fixture parameter names (used in @pytest.mark.parametrize)
@@ -100,11 +101,21 @@ def is_trtllm() -> bool:
return get_runtime() == "trtllm"


def is_mlx() -> bool:
"""Check if tests are running with MLX runtime (Apple Silicon only).

Returns:
True if E2E_RUNTIME is "mlx", False otherwise.
"""
return get_runtime() == "mlx"
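The runtime helpers added here all follow the same env-var gating pattern. A minimal standalone sketch of that pattern, assuming `get_runtime()` reads `E2E_RUNTIME` (the `"sglang"` fallback below is an illustrative assumption, not necessarily the real default):

```python
import os

# Standalone sketch of the runtime-gating pattern; get_runtime() is
# assumed to read E2E_RUNTIME, and the "sglang" fallback is illustrative.
def get_runtime() -> str:
    return os.environ.get("E2E_RUNTIME", "sglang")

def is_mlx() -> bool:
    return get_runtime() == "mlx"

os.environ["E2E_RUNTIME"] = "mlx"
print(is_mlx())  # True
```

With `E2E_RUNTIME=mlx` exported by the workflow above, every `is_mlx()` call site in the suite flips on without plumbing a flag through fixtures.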


# Runtime display labels
RUNTIME_LABELS = {
"sglang": "SGLang",
"vllm": "vLLM",
"trtllm": "TensorRT-LLM",
"mlx": "MLX",
}

ENV_SHOW_ROUTER_LOGS = "SHOW_ROUTER_LOGS"
10 changes: 10 additions & 0 deletions e2e_test/infra/model_specs.py
@@ -184,6 +184,16 @@ def _resolve_model_path(hf_path: str) -> str:
"--enable-chunked-prefill",
],
},
# ── MLX models (Apple Silicon only) ──────────────────────────────────────
# Smallest Qwen3 with native tool calling + thinking mode (~400 MB).
# Used by CI on macos-latest runners. Qwen3 emits <tool_call> tags
# parsed by SMG's --tool-call-parser qwen, and uses <think> tags
# parsed by the reasoning parser.
"mlx-community/Qwen3-0.6B-4bit": {
"model": _resolve_model_path("mlx-community/Qwen3-0.6B-4bit"),
"tp": 1,
"features": ["chat", "streaming", "function_calling", "reasoning", "thinking"],
},
}
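The `<tool_call>` wire format mentioned in the spec comment above can be illustrated with a toy extractor. This is not SMG's actual `--tool-call-parser qwen` implementation, only a sketch of the tag shape Qwen3 emits:

```python
import json
import re

# Toy extractor for Qwen3-style <tool_call> JSON payloads; SMG's real
# parser (--tool-call-parser qwen) is more robust than this sketch.
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def extract_tool_calls(text: str) -> list[dict]:
    return [json.loads(payload) for payload in TOOL_CALL_RE.findall(text)]

sample = (
    "<think>user wants weather, call the tool</think>\n"
    '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
)
print(extract_tool_calls(sample)[0]["name"])  # get_weather
```

The `<think>…</think>` span is handled separately by the reasoning parser; the point here is just that tool calls arrive as tagged JSON, which is why the "function_calling" feature flag can be exercised on a 400 MB model.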


26 changes: 26 additions & 0 deletions e2e_test/infra/worker.py
@@ -178,6 +178,8 @@ def _build_cmd(self) -> list[str]:
return self._build_vllm_http_cmd(model_path, tp_size, spec)
elif self.engine == "trtllm":
return self._build_trtllm_cmd(model_path, tp_size, spec)
elif self.engine == "mlx":
return self._build_mlx_cmd(model_path, spec)
🟡 Nit: `_build_env` (line 328) unconditionally sets `CUDA_VISIBLE_DEVICES` for every engine, including MLX, which runs on Apple Metal rather than CUDA. It's harmless on macOS (the variable is simply ignored), but if you want to keep things tidy you could skip it for MLX:

if self.engine != "mlx":
    env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, self.gpu_ids))

Not blocking; just a readability note for the next person who reads `_build_env`.
else:
raise ValueError(f"Unsupported engine: {self.engine}")

@@ -261,6 +263,30 @@ def _build_vllm_base_cmd(
cmd.extend(extra)
return cmd

def _build_mlx_cmd(self, model_path: str, spec: dict) -> list[str]:
"""Build MLX gRPC server command (Apple Silicon only).

MLX backend only supports gRPC mode (no HTTP variant) since the
servicer wraps mlx-lm's BatchGenerator behind the MlxEngine proto.
"""
if self.mode != ConnectionMode.GRPC:
raise ValueError("MLX backend only supports gRPC mode")
cmd = [
"python3",
"-m",
"smg_grpc_servicer.mlx.server",
"--model",
model_path,
"--host",
DEFAULT_HOST,
"--port",
str(self.port),
]
extra = spec.get("mlx_args", [])
if extra:
cmd.extend(extra)
return cmd
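A standalone sketch of the command this method produces (the host and port below are stand-in values; the real infra supplies `DEFAULT_HOST` and an allocated port):

```python
# Standalone sketch of the MLX server command construction above;
# DEFAULT_HOST and the port value are stand-ins for illustration.
DEFAULT_HOST = "127.0.0.1"

def build_mlx_cmd(model_path: str, port: int, spec: dict) -> list[str]:
    cmd = [
        "python3", "-m", "smg_grpc_servicer.mlx.server",
        "--model", model_path,
        "--host", DEFAULT_HOST,
        "--port", str(port),
    ]
    cmd.extend(spec.get("mlx_args", []))  # optional per-model extras
    return cmd

print(build_mlx_cmd("mlx-community/Qwen3-0.6B-4bit", 50051, {}))
```

Note there is no `tp` handling, matching `_build_mlx_cmd`'s signature: the MLX servicer runs single-device, so tensor parallelism never enters the command line.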

def _build_trtllm_cmd(self, model_path: str, tp_size: int, spec: dict) -> list[str]:
"""Build TensorRT-LLM gRPC server command."""
# Create config file to enable xgrammar guided decoding
Empty file added e2e_test/mlx/__init__.py
Empty file.