diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 97f46132..bafa7570 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -553,8 +553,17 @@ def __init__( # IOKit's asynchronous completeMemory() callbacks, causing # 'prepare count underflow' kernel panics. Deferring the clear # by a few generation steps gives IOKit time to process callbacks. - # None = no deferred clear pending; int = steps since last finish. - self._deferred_clear_steps: Optional[int] = None + # + # Stored as the absolute step number at which the clear should fire, + # rather than a countdown integer. This avoids the burst-completion + # bug (#557): with max_num_seqs > 1 two requests can finish a few + # steps apart. The old "only set if None" guard meant the second + # completion never extended the window, so the second request's KV + # cache blocks could be re-allocated before IOKit finished its + # completeMemory() callbacks. Using max() ensures the window always + # covers the *latest* completion. + # None = no deferred clear pending; int = step at which to fire. + self._deferred_clear_at: Optional[int] = None # Cache XTC special tokens (newline + EOS) — stable per tokenizer. # Must be after _is_harmony_model / _generation_config_eos init @@ -2412,9 +2421,12 @@ def _try_specprefill_scoring(self, request: Request) -> None: except Exception as e: logger.debug(f"SpecPrefill: draft cache store failed: {e}") - # Free draft cache from memory + # Free draft cache from memory. Use _sync_and_clear_cache() so + # the generation_stream is drained before Metal buffers are + # returned to the pool — a bare mx.clear_cache() here can race + # with in-flight async evals and trigger a kernel panic (#557). 
del used_cache - mx.clear_cache() + _sync_and_clear_cache() except Exception as e: logger.error(f"SpecPrefill scoring failed, falling back to normal path: {e}") @@ -2606,11 +2618,11 @@ def has_requests(self) -> bool: Also returns True when a deferred Metal cache clear is pending, so that the engine loop keeps calling step() until the clear fires. - Without this, an idle server would never increment the deferred - counter and stale buffers would accumulate indefinitely. + Without this, an idle server would never reach the target step and + stale buffers would accumulate indefinitely. """ return bool(self.waiting or self.running - or self._deferred_clear_steps is not None) + or self._deferred_clear_at is not None) def fail_all_requests(self) -> List[str]: """Remove all running and waiting requests after unrecoverable error. @@ -2907,7 +2919,14 @@ def _schedule_waiting( self.model(sys_arr[:step][None], cache=sp_cache) mx.eval([c.state for c in sp_cache]) sys_arr = sys_arr[step:] - mx.clear_cache() + # Use _sync_and_clear_cache() instead of bare + # mx.clear_cache() to flush the generation_stream + # before releasing Metal buffers. A bare call here + # can race with in-flight command buffers submitted + # by the preceding mx.eval(), triggering the same + # 'completeMemory() prepare count underflow' kernel + # panic that #435 fixed elsewhere (#557). + _sync_and_clear_cache() if sys_arr.size > 0: self.model(sys_arr[None], cache=sp_cache) mx.eval([c.state for c in sp_cache]) @@ -3405,10 +3424,16 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None: # Deferring by _DEFERRED_CLEAR_DELAY generation steps (~10-40 ms) gives # IOKit time to process callbacks while still reclaiming buffers fast # enough to prevent TTFT spikes from pool bloat (#411). - # Only set if not already pending — otherwise burst completions - # would keep resetting the counter and indefinitely postpone clearing. 
- if self._deferred_clear_steps is None: - self._deferred_clear_steps = 0 + # + # Use max() so that concurrent completions (max_num_seqs > 1) each get + # a full _DEFERRED_CLEAR_DELAY window counted from *their own* finish + # step. The old "only set if None" guard meant the second request's + # window was anchored to the first request's finish step, allowing the + # second request's KV cache blocks to be re-allocated before IOKit + # finished their completeMemory() callbacks (#557). + target = self._step_counter + self._DEFERRED_CLEAR_DELAY + if self._deferred_clear_at is None or target > self._deferred_clear_at: + self._deferred_clear_at = target def _is_cache_corruption_error(self, error: Exception) -> bool: """Check if an error indicates cache corruption.""" @@ -3441,7 +3466,7 @@ def _recover_from_cache_error(self) -> None: self.uid_to_request_id.clear() # Cancel any pending deferred Metal cache clear - self._deferred_clear_steps = None + self._deferred_clear_at = None # Clear detokenizer state to prevent contamination after recovery self._request_detokenizers.clear() @@ -3623,14 +3648,11 @@ def step(self) -> SchedulerOutput: and self._step_counter % self.config.mlx_cache_cleanup_interval == 0 ): should_clear = True - # Deferred post-completion cleanup: wait _DEFERRED_CLEAR_DELAY steps - # after the last request completion to give IOKit time to process - # completeMemory() callbacks before releasing Metal buffers (#435). - if self._deferred_clear_steps is not None: - self._deferred_clear_steps += 1 - if self._deferred_clear_steps >= self._DEFERRED_CLEAR_DELAY: - should_clear = True - self._deferred_clear_steps = None + # Deferred post-completion cleanup: fire once the step counter reaches + # the target set by _cleanup_finished() (#435, #557). 
+ if self._deferred_clear_at is not None and self._step_counter >= self._deferred_clear_at: + should_clear = True + self._deferred_clear_at = None if should_clear: _sync_and_clear_cache() if ( @@ -3702,7 +3724,7 @@ def reset(self) -> None: self._output_parser_sessions.clear() # Cancel any pending deferred Metal cache clear - self._deferred_clear_steps = None + self._deferred_clear_at = None def deep_reset(self) -> None: """