Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 44 additions & 22 deletions omlx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,8 +553,17 @@ def __init__(
# IOKit's asynchronous completeMemory() callbacks, causing
# 'prepare count underflow' kernel panics. Deferring the clear
# by a few generation steps gives IOKit time to process callbacks.
# None = no deferred clear pending; int = steps since last finish.
self._deferred_clear_steps: Optional[int] = None
#
# Stored as the absolute step number at which the clear should fire,
# rather than a countdown integer. This avoids the burst-completion
# bug (#557): with max_num_seqs > 1 two requests in the same batch can
# finish within a few steps of each other. The old "only set if None"
# guard meant the later completion never extended the window, so the
# later request's KV cache blocks could be re-allocated before IOKit
# finished its completeMemory() callbacks. Keeping the larger of the
# pending and newly computed target steps ensures the window always
# covers the *latest* completion.
# None = no deferred clear pending; int = step at which to fire.
self._deferred_clear_at: Optional[int] = None

# Cache XTC special tokens (newline + EOS) — stable per tokenizer.
# Must be after _is_harmony_model / _generation_config_eos init
Expand Down Expand Up @@ -2412,9 +2421,12 @@ def _try_specprefill_scoring(self, request: Request) -> None:
except Exception as e:
logger.debug(f"SpecPrefill: draft cache store failed: {e}")

# Free draft cache from memory
# Free draft cache from memory. Use _sync_and_clear_cache() so
# the generation_stream is drained before Metal buffers are
# returned to the pool — a bare mx.clear_cache() here can race
# with in-flight async evals and trigger a kernel panic (#557).
del used_cache
mx.clear_cache()
_sync_and_clear_cache()

except Exception as e:
logger.error(f"SpecPrefill scoring failed, falling back to normal path: {e}")
Expand Down Expand Up @@ -2606,11 +2618,11 @@ def has_requests(self) -> bool:

Also returns True when a deferred Metal cache clear is pending,
so that the engine loop keeps calling step() until the clear fires.
Without this, an idle server would never increment the deferred
counter and stale buffers would accumulate indefinitely.
Without this, an idle server would never reach the target step and
stale buffers would accumulate indefinitely.
"""
return bool(self.waiting or self.running
or self._deferred_clear_steps is not None)
or self._deferred_clear_at is not None)

def fail_all_requests(self) -> List[str]:
"""Remove all running and waiting requests after unrecoverable error.
Expand Down Expand Up @@ -2907,7 +2919,14 @@ def _schedule_waiting(
self.model(sys_arr[:step][None], cache=sp_cache)
mx.eval([c.state for c in sp_cache])
sys_arr = sys_arr[step:]
mx.clear_cache()
# Use _sync_and_clear_cache() instead of bare
# mx.clear_cache() to flush the generation_stream
# before releasing Metal buffers. A bare call here
# can race with in-flight command buffers submitted
# by the preceding mx.eval(), triggering the same
# 'completeMemory() prepare count underflow' kernel
# panic that #435 fixed elsewhere (#557).
_sync_and_clear_cache()
if sys_arr.size > 0:
self.model(sys_arr[None], cache=sp_cache)
mx.eval([c.state for c in sp_cache])
Expand Down Expand Up @@ -3405,10 +3424,16 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
# Deferring by _DEFERRED_CLEAR_DELAY generation steps (~10-40 ms) gives
# IOKit time to process callbacks while still reclaiming buffers fast
# enough to prevent TTFT spikes from pool bloat (#411).
# Only set if not already pending — otherwise burst completions
# would keep resetting the counter and indefinitely postpone clearing.
if self._deferred_clear_steps is None:
self._deferred_clear_steps = 0
#
# Use max() so that concurrent completions (max_num_seqs > 1) each get
# a full _DEFERRED_CLEAR_DELAY window counted from *their own* finish
# step. The old "only set if None" guard meant the second request's
# window was anchored to the first request's finish step, allowing the
# second request's KV cache blocks to be re-allocated before IOKit
# finished their completeMemory() callbacks (#557).
target = self._step_counter + self._DEFERRED_CLEAR_DELAY
if self._deferred_clear_at is None or target > self._deferred_clear_at:
self._deferred_clear_at = target

def _is_cache_corruption_error(self, error: Exception) -> bool:
"""Check if an error indicates cache corruption."""
Expand Down Expand Up @@ -3441,7 +3466,7 @@ def _recover_from_cache_error(self) -> None:
self.uid_to_request_id.clear()

# Cancel any pending deferred Metal cache clear
self._deferred_clear_steps = None
self._deferred_clear_at = None

# Clear detokenizer state to prevent contamination after recovery
self._request_detokenizers.clear()
Expand Down Expand Up @@ -3623,14 +3648,11 @@ def step(self) -> SchedulerOutput:
and self._step_counter % self.config.mlx_cache_cleanup_interval == 0
):
should_clear = True
# Deferred post-completion cleanup: wait _DEFERRED_CLEAR_DELAY steps
# after the last request completion to give IOKit time to process
# completeMemory() callbacks before releasing Metal buffers (#435).
if self._deferred_clear_steps is not None:
self._deferred_clear_steps += 1
if self._deferred_clear_steps >= self._DEFERRED_CLEAR_DELAY:
should_clear = True
self._deferred_clear_steps = None
# Deferred post-completion cleanup: fire once the step counter reaches
# the target set by _cleanup_finished() (#435, #557).
if self._deferred_clear_at is not None and self._step_counter >= self._deferred_clear_at:
should_clear = True
self._deferred_clear_at = None
if should_clear:
_sync_and_clear_cache()
if (
Expand Down Expand Up @@ -3702,7 +3724,7 @@ def reset(self) -> None:
self._output_parser_sessions.clear()

# Cancel any pending deferred Metal cache clear
self._deferred_clear_steps = None
self._deferred_clear_at = None

def deep_reset(self) -> None:
"""
Expand Down