61 changes: 23 additions & 38 deletions grpc_servicer/smg_grpc_servicer/mlx/server.py
@@ -19,7 +19,6 @@
from grpc_reflection.v1alpha import reflection
from huggingface_hub import snapshot_download
from mlx_lm import load
from mlx_lm.generate import BatchGenerator
from smg_grpc_proto import mlx_engine_pb2, mlx_engine_pb2_grpc

from smg_grpc_servicer.mlx.health_servicer import MlxHealthServicer
@@ -85,39 +84,12 @@ def load_model(args):
    return model, tokenizer, model_dir, model_config, eos_token_ids


def _warmup(batch_generator):
"""Run one end-to-end token through the batch generator so the first
real request doesn't pay JIT/kernel compilation cost."""
logger.info("Running warmup generation...")
try:
uids = batch_generator.insert(prompts=[[1]], max_tokens=[1])
for _ in range(10):
_, gen_responses = batch_generator.next()
if any(r.finish_reason is not None for r in gen_responses if r.uid == uids[0]):
break
batch_generator.remove(uids)
logger.info("Warmup complete")
except Exception:
logger.warning("Warmup failed (non-fatal)", exc_info=True)


async def serve_grpc(args):
"""Start the MLX gRPC server."""
start_time = time.time()

model, tokenizer, model_dir, model_config, eos_token_ids = load_model(args)

batch_generator = BatchGenerator(
model,
completion_batch_size=args.completion_batch_size,
prefill_batch_size=args.prefill_batch_size,
)
logger.info(
"BatchGenerator created (prefill=%d, completion=%d)",
args.prefill_batch_size,
args.completion_batch_size,
)

server = grpc.aio.server(
futures.ThreadPoolExecutor(max_workers=10),
options=[
@@ -131,8 +103,17 @@ async def serve_grpc(args):
    health_servicer = MlxHealthServicer()
    health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)

    # Construct the servicer WITHOUT a BatchGenerator. The BatchGenerator
    # (and its thread-local mlx stream) is built on the generation thread
    # below, so all mlx state lives on one thread — the same threading
    # model as mlx-lm.server. This avoids the cross-thread "no Stream(gpu, 1) in
    # current thread" RuntimeError we saw when an mx.async_eval
    # continuation tried to look up the stream context on a thread that
    # never bound it.
    servicer = MlxEngineServicer(
        batch_generator=batch_generator,
        model=model,
        completion_batch_size=args.completion_batch_size,
        prefill_batch_size=args.prefill_batch_size,
        model_path=args.model,
        model_dir=model_dir,
        model_config=model_config,
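The comment above describes machinery this diff only gestures at: the servicer now builds the BatchGenerator on its own generation thread and signals readiness back to `serve_grpc`. A minimal sketch of how that could look, reusing the module's existing `BatchGenerator` and `logger`; the names `_ready`, `_ready_ok`, `_stop`, `_gen_main`, and `_run_loop` are illustrative assumptions, not code from this PR:

```python
import threading

class MlxEngineServicer:  # illustrative skeleton only
    def start_generation_loop(self):
        self._ready = threading.Event()
        self._ready_ok = False
        self._stop = threading.Event()
        self._gen_thread = threading.Thread(target=self._gen_main, daemon=True)
        self._gen_thread.start()

    def wait_ready(self):
        # Blocks until the gen thread signals; True only if construction
        # and warmup both succeeded.
        self._ready.wait()
        return self._ready_ok

    def _gen_main(self):
        try:
            # Built here so the BatchGenerator's thread-local mlx stream is
            # bound to this thread, and only this thread touches mlx state.
            self._batch_generator = BatchGenerator(
                self._model,
                completion_batch_size=self._completion_batch_size,
                prefill_batch_size=self._prefill_batch_size,
            )
            self._warmup()  # one end-to-end token, as the removed _warmup did
            self._ready_ok = True
        except Exception:
            logger.exception("BatchGenerator construction/warmup failed")
        finally:
            self._ready.set()  # unblock wait_ready() on success or failure
        if self._ready_ok:
            self._run_loop()  # dispatch tokens until stop_generation_loop()
```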
@@ -153,16 +134,21 @@ async def serve_grpc(args):
    if bound_port == 0:
        raise RuntimeError(f"Failed to bind gRPC server to {listen_addr}")

    # Warmup BEFORE starting the generation loop (batch_generator.next() is
    # not thread-safe — only one caller at a time).
    _warmup(batch_generator)
    # The gen thread runs construction → warmup → main loop. Wait for it
    # to signal ready before flipping the health check to SERVING,
    # otherwise a Generate RPC could slip into the window where the gen
    # thread hasn't constructed BatchGenerator yet and block forever on
    # _pending. wait_ready() returns False if BatchGenerator construction
    # raised on the gen thread — fail startup loudly in that case rather
    # than advertising a healthy server with a dead gen thread that hangs
    # every Generate RPC on _pending.
    servicer.start_generation_loop()
    loop = asyncio.get_running_loop()
    ready = await loop.run_in_executor(None, servicer.wait_ready)
    if not ready:
        servicer.stop_generation_loop()
        raise RuntimeError("MLX generation thread failed to become ready — see preceding logs")

    # Only accept RPCs after the generation loop is running. Otherwise a
    # Generate RPC could slip into the window between server.start() and
    # start_generation_loop() and block forever on queue.get() because no
    # gen thread is dispatching tokens. HealthCheck always returns OK, so
    # the router can't use it to detect this window.
    await server.start()
    health_servicer.set_serving()
logger.info("gRPC server listening on %s — model: %s", listen_addr, args.model)
@@ -187,7 +173,6 @@ def signal_handler():
    # loop first would leave new/in-flight RPCs stranded.
    await server.stop(5.0)
    servicer.stop_generation_loop()
    batch_generator.close()
    logger.info("Server stopped")

