47 changes: 38 additions & 9 deletions pymllm/orchestrator/scheduler_process.py
@@ -55,6 +55,11 @@
_DEFAULT_MAX_TOTAL_TOKENS = 131072
_DEFAULT_MAX_NEW_TOKENS = 32768

# Brief poll timeout (ms) used between decode batches to avoid 100% CPU spin.
# 1 ms is enough to yield the CPU core to the OS scheduler while adding
# negligible latency (decode steps typically take >1 ms on the GPU anyway).
_DECODE_POLL_TIMEOUT_MS = 1


# ======================================================================
# IdleSleeper -- avoid busy-looping when no work is available
@@ -482,20 +487,30 @@ def init_model(self) -> None:
logger.info("In-process model runner initialised on GPU %d", self._gpu_id)

def event_loop(self) -> None:
"""Infinite scheduling loop."""
"""Infinite scheduling loop.

When decode batches are active, the loop would otherwise spin at
100% CPU doing non-blocking ZMQ polls between GPU forward passes.
We track whether the previous iteration ran a decode batch and, if
so, use a brief poll timeout (default 1 ms) in ``recv_requests``
so the OS can schedule other work on this core.
"""
logger.info(
"SchedulerProcess event loop started (shared_queue=%s, transport=%s)",
self._enable_shared_queue,
self._tensor_transport_mode,
)
_in_decode = False
while True:
self.recv_requests(brief_poll=_in_decode)
self.process_input_requests()
batch = self.get_next_batch_to_run()
if batch is not None:
_in_decode = not batch.forward_mode.is_extend()
result = self.run_batch(batch)
self.process_batch_result(batch, result)
else:
_in_decode = False
# No work available -- sleep until a new request arrives
# on the ZMQ socket (or timeout). Avoids busy-looping.
self._idle_sleeper.sleep()
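
The change above reads as a general pattern. A minimal standalone sketch of the
decode-aware drain loop (the function and socket names here are illustrative,
not from the PR; only the timeout handoff mirrors the actual change):

import zmq

_DECODE_POLL_TIMEOUT_MS = 1  # same value as the constant added above


def drain_requests(poller: zmq.Poller, sock: zmq.Socket, brief_poll: bool) -> list:
    """Drain every pending message; block at most once, for at most 1 ms."""
    msgs = []
    timeout = _DECODE_POLL_TIMEOUT_MS if brief_poll else 0
    while True:
        events = dict(poller.poll(timeout=timeout))
        if sock not in events:
            return msgs
        timeout = 0  # remaining messages are already queued; don't wait again
        msgs.append(sock.recv_pyobj(zmq.NOBLOCK))
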
@@ -505,30 +520,40 @@ def event_loop(self) -> None:
# Step 1: receive tokenized requests (non-blocking)
# ------------------------------------------------------------------

def recv_requests(self, brief_poll: bool = False) -> None:
"""Non-blocking receive of tokenized requests from TokenizerProcess.

Supports two modes:
1. Legacy ZMQ: Uses ``zmq.Poller`` with a short timeout
2. Shared queue: Non-blocking get from multiprocessing.Queue

When *brief_poll* is ``True`` (typically during active decode), the
first poll uses a small timeout (``_DECODE_POLL_TIMEOUT_MS``) instead
of zero. This yields the CPU core to the OS scheduler between decode
batches while adding negligible latency.

Messages are either:
* A :class:`~pymllm.engine.io_struct.TokenizedGenerateReqInput`
dataclass – appended to ``_waiting_queue``.
* A plain abort sentinel dict ``{"rid": ..., "abort": True}`` – handled
inline by removing the matching rid from the waiting queue.
"""
if self._enable_shared_queue and self._shared_queue is not None:
self._recv_from_shared_queue(brief_poll=brief_poll)
else:
self._recv_from_zmq(brief_poll=brief_poll)

def _recv_from_zmq(self, brief_poll: bool = False) -> None:
"""Receive requests via legacy ZMQ path."""
# On the first poll, use a brief timeout if requested (decode path)
# to yield the CPU. After draining the first message, switch to
# non-blocking for any remaining queued messages.
poll_timeout = _DECODE_POLL_TIMEOUT_MS if brief_poll else 0
while True:
events = dict(self._poller.poll(timeout=poll_timeout))
if self._recv_from_tokenizer not in events:
break
poll_timeout = 0 # drain remaining messages without blocking
msg = self._recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
# Abort sentinel: plain dict with "abort" key.
if isinstance(msg, dict) and msg.get("abort"):
@@ -542,7 +567,7 @@ def _recv_from_zmq(self) -> None:
else:
self._waiting_queue.append(msg)

def _recv_from_shared_queue(self, brief_poll: bool = False) -> None:
"""Receive requests via shared memory + shared queue fast path.

After reading a ``(rid, shm_name, mm_inputs)`` tuple from the queue:
@@ -556,9 +581,13 @@ def _recv_from_shared_queue(self) -> None:
3. A full ``TokenizedGenerateReqInput`` is assembled and appended to
``_waiting_queue``.
"""
# Use a slightly longer timeout on the first get when in decode mode
# to yield CPU; subsequent gets use a short timeout to drain the queue.
get_timeout = _DECODE_POLL_TIMEOUT_MS / 1000.0 if brief_poll else 0.002
while True:
try:
rid, shm_name, mm_inputs = self._shared_queue.get(timeout=get_timeout)
get_timeout = 0.002 # drain remaining without extra delay

# Read metadata from shared memory (and unlink immediately)
metadata: TokenizedGenerateReqInput = SharedMemoryManager.read_metadata(
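
The shared-queue path applies the same handoff with timeouts in seconds rather
than milliseconds. A minimal sketch under the same caveats (hypothetical names;
the real code also reads request metadata from shared memory per item, which is
omitted here):

import queue


def drain_shared_queue(q, brief_poll: bool) -> list:
    """Drain a multiprocessing.Queue; only the first get() may wait 1 ms."""
    items = []
    timeout = 0.001 if brief_poll else 0.002  # _DECODE_POLL_TIMEOUT_MS / 1000
    while True:
        try:
            items.append(q.get(timeout=timeout))
            timeout = 0.002  # keep draining without adding extra delay
        except queue.Empty:
            return items
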
113 changes: 113 additions & 0 deletions pymllm/tests/bench_cpu_busy_loop.py
@@ -0,0 +1,113 @@
#!/usr/bin/env python3
"""Benchmark: CPU busy-loop vs brief-poll in the scheduler event loop.

Simulates the scheduler's decode loop (poll → "forward" → poll → ...)
and measures CPU usage under both strategies.

Usage:
python pymllm/tests/bench_cpu_busy_loop.py

What to look for:
- "CPU usage" percentage: spin-poll should be ~100%, brief-poll should be <10%
- "Wall time" should be similar (brief-poll adds ~1ms per iteration)
- "Throughput" (iterations/sec) shows the latency cost of the brief poll
"""

import os
import time

import zmq


def run_loop(poller, sock, poll_timeout_ms: int, duration_s: float = 2.0):
"""Run the scheduler-style poll loop for *duration_s* seconds.

The loop body does NO simulated work — this isolates the poll overhead,
which is exactly what happens in the real scheduler between GPU kernel
launches (the CPU thread is free while the GPU computes; it's the poll
call that either spins or yields).

Returns (wall_time, cpu_time, iterations).
"""
iterations = 0
t0_wall = time.monotonic()
t0_cpu = time.process_time()
deadline = t0_wall + duration_s

while time.monotonic() < deadline:
# Poll for new requests (this is where CPU spins or yields)
timeout = poll_timeout_ms
while True:
events = dict(poller.poll(timeout=timeout))
if sock not in events:
break
timeout = 0 # drain remaining
sock.recv(zmq.NOBLOCK) # consume message
iterations += 1

wall = time.monotonic() - t0_wall
cpu = time.process_time() - t0_cpu
return wall, cpu, iterations


def main():
ctx = zmq.Context()
sock = ctx.socket(zmq.PULL)
addr = f"inproc://bench-{os.getpid()}"
sock.bind(addr)

poller = zmq.Poller()
poller.register(sock, zmq.POLLIN)

duration = 3.0 # seconds per test

print("=" * 64)
print("Scheduler CPU Busy-Loop Benchmark")
print("=" * 64)
print(f"Each test runs for {duration:.0f}s simulating the scheduler poll loop")
print(f"(poll for requests → loop back, no simulated GPU work)")
print()

# --- Spin poll (timeout=0) ---
print("Running SPIN POLL (timeout=0) ...")
spin_wall, spin_cpu, spin_iters = run_loop(poller, sock, 0, duration)
spin_pct = 100.0 * spin_cpu / max(spin_wall, 1e-9)
spin_throughput = spin_iters / max(spin_wall, 1e-9)

# --- Brief poll (timeout=1ms) ---
print("Running BRIEF POLL (timeout=1ms) ...")
brief_wall, brief_cpu, brief_iters = run_loop(poller, sock, 1, duration)
brief_pct = 100.0 * brief_cpu / max(brief_wall, 1e-9)
brief_throughput = brief_iters / max(brief_wall, 1e-9)

sock.close()
ctx.term()

# --- Results ---
print()
print("-" * 64)
print(f"{'Metric':<30} {'Spin (before)':>15} {'Brief (after)':>15}")
print("-" * 64)
print(f"{'Wall time (s)':<30} {spin_wall:>15.3f} {brief_wall:>15.3f}")
print(f"{'CPU time (s)':<30} {spin_cpu:>15.3f} {brief_cpu:>15.3f}")
print(f"{'CPU usage (%)':<30} {spin_pct:>14.1f}% {brief_pct:>14.1f}%")
print(f"{'Iterations':<30} {spin_iters:>15d} {brief_iters:>15d}")
print(f"{'Throughput (iter/s)':<30} {spin_throughput:>15.1f} {brief_throughput:>15.1f}")
print("-" * 64)

reduction = spin_pct - brief_pct
throughput_cost = 100.0 * (1 - brief_throughput / spin_throughput) if spin_throughput > 0 else 0.0
print()
print(f"CPU usage reduction: {reduction:+.1f} percentage points")
print(f"Throughput cost: {throughput_cost:.1f}% fewer iterations/sec")
print()
if reduction > 20:
print("RESULT: Significant CPU savings with negligible throughput cost.")
elif reduction > 5:
print("RESULT: Moderate CPU savings.")
else:
print("RESULT: Minimal difference (forward pass dominates loop time).")


if __name__ == "__main__":
main()
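
The loop above deliberately omits per-iteration work, which makes the spin case
look maximally bad. A hypothetical variant (not part of this PR) that could sit
alongside run_loop in this module, reusing its time and zmq imports, sleeps to
stand in for a GPU forward pass and reproduces the "forward pass dominates" case
in the final result branch:

def run_loop_with_work(poller, sock, poll_timeout_ms: int,
                       work_s: float = 0.005, duration_s: float = 2.0):
    """Like run_loop, but sleeps *work_s* per iteration to mimic a forward pass."""
    iterations = 0
    t0_wall, t0_cpu = time.monotonic(), time.process_time()
    deadline = t0_wall + duration_s
    while time.monotonic() < deadline:
        timeout = poll_timeout_ms
        while True:
            events = dict(poller.poll(timeout=timeout))
            if sock not in events:
                break
            timeout = 0  # drain remaining
            sock.recv(zmq.NOBLOCK)
        time.sleep(work_s)  # CPU idles here, as it would during a GPU kernel
        iterations += 1
    return time.monotonic() - t0_wall, time.process_time() - t0_cpu, iterations

With a 5 ms simulated forward pass, both strategies spend most of each iteration
asleep, so the CPU-usage gap between spin-poll and brief-poll shrinks accordingly.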