@@ -562,8 +562,17 @@ def __init__(
562562 # IOKit's asynchronous completeMemory() callbacks, causing
563563 # 'prepare count underflow' kernel panics. Deferring the clear
564564 # by a few generation steps gives IOKit time to process callbacks.
565- # None = no deferred clear pending; int = steps since last finish.
566- self ._deferred_clear_steps : Optional [int ] = None
565+ #
566+ # Stored as the absolute step number at which the clear should fire,
567+ # rather than a countdown integer. This avoids the burst-completion
568+ # bug (#557): with max_num_seqs > 1 two requests can finish in the
569+ # same batch. The old "only set if None" guard meant the second
570+ # completion never extended the window, so the first request's KV
571+ # cache blocks could be re-allocated before IOKit finished its
572+ # completeMemory() callbacks. Using max() ensures the window always
573+ # covers the *latest* completion.
574+ # None = no deferred clear pending; int = step at which to fire.
575+ self ._deferred_clear_at : Optional [int ] = None
567576
568577 # Cache XTC special tokens (newline + EOS) — stable per tokenizer.
569578 # Must be after _is_harmony_model / _generation_config_eos init
@@ -2455,9 +2464,12 @@ def _try_specprefill_scoring(self, request: Request) -> None:
24552464 except Exception as e :
24562465 logger .debug (f"SpecPrefill: draft cache store failed: { e } " )
24572466
2458- # Free draft cache from memory
2467+ # Free draft cache from memory. Use _sync_and_clear_cache() so
2468+ # the generation_stream is drained before Metal buffers are
2469+ # returned to the pool — a bare mx.clear_cache() here can race
2470+ # with in-flight async evals and trigger a kernel panic (#557).
24592471 del used_cache
2460- mx . clear_cache ()
2472+ _sync_and_clear_cache ()
24612473
24622474 except Exception as e :
24632475 logger .error (f"SpecPrefill scoring failed, falling back to normal path: { e } " )
@@ -2651,11 +2663,11 @@ def has_requests(self) -> bool:
26512663
26522664 Also returns True when a deferred Metal cache clear is pending,
26532665 so that the engine loop keeps calling step() until the clear fires.
2654- Without this, an idle server would never increment the deferred
2655- counter and stale buffers would accumulate indefinitely.
2666+ Without this, an idle server would never reach the target step and
2667+ stale buffers would accumulate indefinitely.
26562668 """
26572669 return bool (self .waiting or self .running
2658- or self ._deferred_clear_steps is not None )
2670+ or self ._deferred_clear_at is not None )
26592671
26602672 def fail_all_requests (self ) -> List [str ]:
26612673 """Remove all running and waiting requests after unrecoverable error.
@@ -2952,7 +2964,14 @@ def _schedule_waiting(
29522964 self .model (sys_arr [:step ][None ], cache = sp_cache )
29532965 mx .eval ([c .state for c in sp_cache ])
29542966 sys_arr = sys_arr [step :]
2955- mx .clear_cache ()
2967+ # Use _sync_and_clear_cache() instead of bare
2968+ # mx.clear_cache() to flush the generation_stream
2969+ # before releasing Metal buffers. A bare call here
2970+ # can race with in-flight command buffers submitted
2971+ # by the preceding mx.eval(), triggering the same
2972+ # 'completeMemory() prepare count underflow' kernel
2973+ # panic that #435 fixed elsewhere (#557).
2974+ _sync_and_clear_cache ()
29562975 if sys_arr .size > 0 :
29572976 self .model (sys_arr [None ], cache = sp_cache )
29582977 mx .eval ([c .state for c in sp_cache ])
@@ -3471,10 +3490,16 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
34713490 # Deferring by _DEFERRED_CLEAR_DELAY generation steps (~10-40 ms) gives
34723491 # IOKit time to process callbacks while still reclaiming buffers fast
34733492 # enough to prevent TTFT spikes from pool bloat (#411).
3474- # Only set if not already pending — otherwise burst completions
3475- # would keep resetting the counter and indefinitely postpone clearing.
3476- if self ._deferred_clear_steps is None :
3477- self ._deferred_clear_steps = 0
3493+ #
3494+ # Use max() so that concurrent completions (max_num_seqs > 1) each get
3495+ # a full _DEFERRED_CLEAR_DELAY window counted from *their own* finish
3496+ # step. The old "only set if None" guard meant the second request's
3497+ # window was anchored to the first request's finish step, allowing the
3498+ # second request's KV cache blocks to be re-allocated before IOKit
3499+ # finished their completeMemory() callbacks (#557).
3500+ target = self ._step_counter + self ._DEFERRED_CLEAR_DELAY
3501+ if self ._deferred_clear_at is None or target > self ._deferred_clear_at :
3502+ self ._deferred_clear_at = target
34783503
34793504 def _is_cache_corruption_error (self , error : Exception ) -> bool :
34803505 """Check if an error indicates cache corruption."""
@@ -3507,7 +3532,7 @@ def _recover_from_cache_error(self) -> None:
35073532 self .uid_to_request_id .clear ()
35083533
35093534 # Cancel any pending deferred Metal cache clear
3510- self ._deferred_clear_steps = None
3535+ self ._deferred_clear_at = None
35113536
35123537 # Clear detokenizer state to prevent contamination after recovery
35133538 self ._request_detokenizers .clear ()
@@ -3689,14 +3714,11 @@ def step(self) -> SchedulerOutput:
36893714 and self ._step_counter % self .config .mlx_cache_cleanup_interval == 0
36903715 ):
36913716 should_clear = True
3692- # Deferred post-completion cleanup: wait _DEFERRED_CLEAR_DELAY steps
3693- # after the last request completion to give IOKit time to process
3694- # completeMemory() callbacks before releasing Metal buffers (#435).
3695- if self ._deferred_clear_steps is not None :
3696- self ._deferred_clear_steps += 1
3697- if self ._deferred_clear_steps >= self ._DEFERRED_CLEAR_DELAY :
3698- should_clear = True
3699- self ._deferred_clear_steps = None
3717+ # Deferred post-completion cleanup: fire once the step counter reaches
3718+ # the target set by _cleanup_finished() (#435, #557).
3719+ if self ._deferred_clear_at is not None and self ._step_counter >= self ._deferred_clear_at :
3720+ should_clear = True
3721+ self ._deferred_clear_at = None
37003722 if should_clear :
37013723 _sync_and_clear_cache ()
37023724 if (
@@ -3768,7 +3790,7 @@ def reset(self) -> None:
37683790 self ._output_parser_sessions .clear ()
37693791
37703792 # Cancel any pending deferred Metal cache clear
3771- self ._deferred_clear_steps = None
3793+ self ._deferred_clear_at = None
37723794
37733795 def deep_reset (self ) -> None :
37743796 """
0 commit comments