Skip to content

Commit f7905e7

Browse files
authored
Merge pull request #648 from Chuhan1112/fix/iogpu-prepare-count-underflow-557
fix: prevent IOKit prepare count underflow with concurrent completions (#557)
2 parents a2ca277 + ac531fd commit f7905e7

1 file changed

Lines changed: 44 additions & 22 deletions

File tree

omlx/scheduler.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -562,8 +562,17 @@ def __init__(
562562
# IOKit's asynchronous completeMemory() callbacks, causing
563563
# 'prepare count underflow' kernel panics. Deferring the clear
564564
# by a few generation steps gives IOKit time to process callbacks.
565-
# None = no deferred clear pending; int = steps since last finish.
566-
self._deferred_clear_steps: Optional[int] = None
565+
#
566+
# Stored as the absolute step number at which the clear should fire,
567+
# rather than a countdown integer. This avoids the burst-completion
568+
# bug (#557): with max_num_seqs > 1 two requests can finish in the
569+
# same batch. The old "only set if None" guard meant the second
570+
# completion never extended the window, so the second request's KV
571+
# cache blocks could be re-allocated before IOKit finished its
572+
# completeMemory() callbacks. Using max() ensures the window always
573+
# covers the *latest* completion.
574+
# None = no deferred clear pending; int = step at which to fire.
575+
self._deferred_clear_at: Optional[int] = None
567576

568577
# Cache XTC special tokens (newline + EOS) — stable per tokenizer.
569578
# Must be after _is_harmony_model / _generation_config_eos init
@@ -2455,9 +2464,12 @@ def _try_specprefill_scoring(self, request: Request) -> None:
24552464
except Exception as e:
24562465
logger.debug(f"SpecPrefill: draft cache store failed: {e}")
24572466

2458-
# Free draft cache from memory
2467+
# Free draft cache from memory. Use _sync_and_clear_cache() so
2468+
# the generation_stream is drained before Metal buffers are
2469+
# returned to the pool — a bare mx.clear_cache() here can race
2470+
# with in-flight async evals and trigger a kernel panic (#557).
24592471
del used_cache
2460-
mx.clear_cache()
2472+
_sync_and_clear_cache()
24612473

24622474
except Exception as e:
24632475
logger.error(f"SpecPrefill scoring failed, falling back to normal path: {e}")
@@ -2651,11 +2663,11 @@ def has_requests(self) -> bool:
26512663
26522664
Also returns True when a deferred Metal cache clear is pending,
26532665
so that the engine loop keeps calling step() until the clear fires.
2654-
Without this, an idle server would never increment the deferred
2655-
counter and stale buffers would accumulate indefinitely.
2666+
Without this, an idle server would never reach the target step and
2667+
stale buffers would accumulate indefinitely.
26562668
"""
26572669
return bool(self.waiting or self.running
2658-
or self._deferred_clear_steps is not None)
2670+
or self._deferred_clear_at is not None)
26592671

26602672
def fail_all_requests(self) -> List[str]:
26612673
"""Remove all running and waiting requests after unrecoverable error.
@@ -2952,7 +2964,14 @@ def _schedule_waiting(
29522964
self.model(sys_arr[:step][None], cache=sp_cache)
29532965
mx.eval([c.state for c in sp_cache])
29542966
sys_arr = sys_arr[step:]
2955-
mx.clear_cache()
2967+
# Use _sync_and_clear_cache() instead of bare
2968+
# mx.clear_cache() to flush the generation_stream
2969+
# before releasing Metal buffers. A bare call here
2970+
# can race with in-flight command buffers submitted
2971+
# by the preceding mx.eval(), triggering the same
2972+
# 'completeMemory() prepare count underflow' kernel
2973+
# panic that #435 fixed elsewhere (#557).
2974+
_sync_and_clear_cache()
29562975
if sys_arr.size > 0:
29572976
self.model(sys_arr[None], cache=sp_cache)
29582977
mx.eval([c.state for c in sp_cache])
@@ -3471,10 +3490,16 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
34713490
# Deferring by _DEFERRED_CLEAR_DELAY generation steps (~10-40 ms) gives
34723491
# IOKit time to process callbacks while still reclaiming buffers fast
34733492
# enough to prevent TTFT spikes from pool bloat (#411).
3474-
# Only set if not already pending — otherwise burst completions
3475-
# would keep resetting the counter and indefinitely postpone clearing.
3476-
if self._deferred_clear_steps is None:
3477-
self._deferred_clear_steps = 0
3493+
#
3494+
# Use max() so that concurrent completions (max_num_seqs > 1) each get
3495+
# a full _DEFERRED_CLEAR_DELAY window counted from *their own* finish
3496+
# step. The old "only set if None" guard meant the second request's
3497+
# window was anchored to the first request's finish step, allowing the
3498+
# second request's KV cache blocks to be re-allocated before IOKit
3499+
# finished their completeMemory() callbacks (#557).
3500+
target = self._step_counter + self._DEFERRED_CLEAR_DELAY
3501+
if self._deferred_clear_at is None or target > self._deferred_clear_at:
3502+
self._deferred_clear_at = target
34783503

34793504
def _is_cache_corruption_error(self, error: Exception) -> bool:
34803505
"""Check if an error indicates cache corruption."""
@@ -3507,7 +3532,7 @@ def _recover_from_cache_error(self) -> None:
35073532
self.uid_to_request_id.clear()
35083533

35093534
# Cancel any pending deferred Metal cache clear
3510-
self._deferred_clear_steps = None
3535+
self._deferred_clear_at = None
35113536

35123537
# Clear detokenizer state to prevent contamination after recovery
35133538
self._request_detokenizers.clear()
@@ -3689,14 +3714,11 @@ def step(self) -> SchedulerOutput:
36893714
and self._step_counter % self.config.mlx_cache_cleanup_interval == 0
36903715
):
36913716
should_clear = True
3692-
# Deferred post-completion cleanup: wait _DEFERRED_CLEAR_DELAY steps
3693-
# after the last request completion to give IOKit time to process
3694-
# completeMemory() callbacks before releasing Metal buffers (#435).
3695-
if self._deferred_clear_steps is not None:
3696-
self._deferred_clear_steps += 1
3697-
if self._deferred_clear_steps >= self._DEFERRED_CLEAR_DELAY:
3698-
should_clear = True
3699-
self._deferred_clear_steps = None
3717+
# Deferred post-completion cleanup: fire once the step counter reaches
3718+
# the target set by _cleanup_finished() (#435, #557).
3719+
if self._deferred_clear_at is not None and self._step_counter >= self._deferred_clear_at:
3720+
should_clear = True
3721+
self._deferred_clear_at = None
37003722
if should_clear:
37013723
_sync_and_clear_cache()
37023724
if (
@@ -3768,7 +3790,7 @@ def reset(self) -> None:
37683790
self._output_parser_sessions.clear()
37693791

37703792
# Cancel any pending deferred Metal cache clear
3771-
self._deferred_clear_steps = None
3793+
self._deferred_clear_at = None
37723794

37733795
def deep_reset(self) -> None:
37743796
"""

0 commit comments

Comments
 (0)