@@ -3138,34 +3138,34 @@ def reshape_for_broadcast(slot, name_suffix):
31383138 )
31393139 )
31403140
3141- # Step 3: inv_std = rsqrt(var_eps)
3142- _ , tmp_inv_std = P .make_tmp_slot ()
3143- P .emit (RsqrtNode (x = P .slot_to_tid (tmp_var_eps ), out = P .slot_to_tid (tmp_inv_std )))
3141+ # Step 3: inv_std = rsqrt(var_eps), written in place so MLX can donate the
3142+ # var_eps buffer (unary, same shape/dtype).
3143+ P .emit (RsqrtNode (x = P .slot_to_tid (tmp_var_eps ), out = P .slot_to_tid (tmp_var_eps )))
3144+ tmp_inv_std = tmp_var_eps
31443145
3145- # Step 4: x_normalized = x_centered * inv_std
3146- _ , tmp_normalized = P . make_tmp_slot ()
3146+ # Step 4: x_normalized = x_centered * inv_std, written in place into the
3147+ # full-size x_centered buffer (the donatable operand; inv_std broadcasts).
31473148 P .emit (
31483149 MultiplyNode (
31493150 a = P .slot_to_tid (tmp_centered ),
31503151 b = P .slot_to_tid (tmp_inv_std ),
3151- out = P .slot_to_tid (tmp_normalized ),
3152+ out = P .slot_to_tid (tmp_centered ),
31523153 )
31533154 )
3155+ tmp_normalized = tmp_centered
31543156
31553157 # Step 5: x_scaled = x_normalized * weight (skip if weight is None, i.e. affine=False)
31563158 if weight is not None :
31573159 weight_reshaped = reshape_for_broadcast (weight , "weight" )
3158- _ , tmp_scaled = P . make_tmp_slot ()
3160+ # In place into the full-size x_normalized buffer (weight broadcasts).
31593161 P .emit (
31603162 MultiplyNode (
31613163 a = P .slot_to_tid (tmp_normalized ),
31623164 b = P .slot_to_tid (weight_reshaped ),
3163- out = P .slot_to_tid (tmp_scaled ),
3165+ out = P .slot_to_tid (tmp_normalized ),
31643166 )
31653167 )
3166- current_result = tmp_scaled
3167- else :
3168- current_result = tmp_normalized
3168+ current_result = tmp_normalized
31693169
31703170 # Step 6: out = current_result + bias (skip if bias is None, i.e. affine=False)
31713171 if bias is not None :
@@ -3558,47 +3558,39 @@ def emit_sample():
35583558 seed_field = P .slot_to_vid (seed_val )
35593559
35603560 # uniform u in [0, 1): bits/uint32_max, clamped just below 1 (random.cpp:95)
3561- _ , bits = P .make_tmp_slot ()
3561+ _ , u = P .make_tmp_slot ()
35623562 P .emit (
3563- RandomBitsNode (
3564- out = P .slot_to_tid (bits ), shape = shape , width = 4 , seed = seed_field
3565- )
3563+ RandomBitsNode (out = P .slot_to_tid (u ), shape = shape , width = 4 , seed = seed_field )
35663564 )
3567- _ , bits_f = P . make_tmp_slot ()
3565+ # u32 -> f32 in place (same itemsize, donatable), then divide/clamp in place.
35683566 P .emit (
35693567 AsTypeNode (
3570- x = P .slot_to_tid (bits ),
3571- out = P .slot_to_tid (bits_f ),
3568+ x = P .slot_to_tid (u ),
3569+ out = P .slot_to_tid (u ),
35723570 scalar_type = torch_dtype_to_scalar_type (torch .float32 ),
35733571 )
35743572 )
35753573 umax = emit_lifted_constant (P , 4294967295.0 , torch .float32 )
3576- _ , div0 = P .make_tmp_slot ()
35773574 P .emit (
3578- DivideNode (
3579- a = P .slot_to_tid (bits_f ), b = P .slot_to_tid (umax ), out = P .slot_to_tid (div0 )
3580- )
3575+ DivideNode (a = P .slot_to_tid (u ), b = P .slot_to_tid (umax ), out = P .slot_to_tid (u ))
35813576 )
35823577 prev1 = emit_lifted_constant (
35833578 P ,
35843579 float (torch .nextafter (torch .tensor (1.0 ), torch .tensor (0.0 ))),
35853580 torch .float32 ,
35863581 )
3587- _ , clamp = P .make_tmp_slot ()
35883582 P .emit (
35893583 MinimumNode (
3590- a = P .slot_to_tid (div0 ), b = P .slot_to_tid (prev1 ), out = P .slot_to_tid (clamp )
3584+ a = P .slot_to_tid (u ), b = P .slot_to_tid (prev1 ), out = P .slot_to_tid (u )
35913585 )
35923586 )
35933587 # gumbel g = -log(-log(u)); whole chain stays fp32 (bf16 mis-ranks ties; clamp->1.0->+inf).
3594- _ , l1 = P .make_tmp_slot ()
3595- P .emit (LogNode (x = P .slot_to_tid (clamp ), out = P .slot_to_tid (l1 )))
3596- _ , g1 = P .make_tmp_slot ()
3597- P .emit (NegNode (x = P .slot_to_tid (l1 ), out = P .slot_to_tid (g1 )))
3598- _ , l2 = P .make_tmp_slot ()
3599- P .emit (LogNode (x = P .slot_to_tid (g1 ), out = P .slot_to_tid (l2 )))
3588+ # All links are single-use unary ops, so reuse one buffer in place.
36003589 _ , g = P .make_tmp_slot ()
3601- P .emit (NegNode (x = P .slot_to_tid (l2 ), out = P .slot_to_tid (g )))
3590+ P .emit (LogNode (x = P .slot_to_tid (u ), out = P .slot_to_tid (g )))
3591+ P .emit (NegNode (x = P .slot_to_tid (g ), out = P .slot_to_tid (g )))
3592+ P .emit (LogNode (x = P .slot_to_tid (g ), out = P .slot_to_tid (g )))
3593+ P .emit (NegNode (x = P .slot_to_tid (g ), out = P .slot_to_tid (g )))
36023594
36033595 # sample: argmax(logits / temperature + g) over the vocab axis, in float32
36043596 _ , logits_f = P .make_tmp_slot ()
@@ -3609,34 +3601,42 @@ def emit_sample():
36093601 scalar_type = torch_dtype_to_scalar_type (torch .float32 ),
36103602 )
36113603 )
3612- _ , scaled = P .make_tmp_slot ()
3604+ # logits_f is single-use; divide in place. The result (scaled) is read
3605+ # twice (softmax below and the final where), so this buffer must live
3606+ # until then.
36133607 P .emit (
36143608 DivideNode (
36153609 a = P .slot_to_tid (logits_f ),
36163610 b = P .slot_to_tid (temperature ),
3617- out = P .slot_to_tid (scaled ),
3611+ out = P .slot_to_tid (logits_f ),
36183612 )
36193613 )
3614+ scaled = logits_f
36203615
36213616 # top-p nucleus mask; SortNode is ascending-only, so sort -probs for descending.
3617+ # probs is read twice (neg_p below and the drop comparison), keep separate.
36223618 _ , probs = P .make_tmp_slot ()
36233619 P .emit (SoftmaxNode (x = P .slot_to_tid (scaled ), out = P .slot_to_tid (probs ), axis = - 1 ))
3624- _ , neg_p = P .make_tmp_slot ()
3625- P .emit (NegNode (x = P .slot_to_tid (probs ), out = P .slot_to_tid (neg_p )))
3626- _ , sorted_neg = P .make_tmp_slot ()
3627- P .emit (SortNode (x = P .slot_to_tid (neg_p ), out = P .slot_to_tid (sorted_neg ), axis = - 1 ))
3620+ # neg_p -> sort -> neg are single-use; thread one buffer.
36283621 _ , sorted_p = P .make_tmp_slot ()
3629- P .emit (NegNode (x = P .slot_to_tid (sorted_neg ), out = P .slot_to_tid (sorted_p )))
3622+ P .emit (NegNode (x = P .slot_to_tid (probs ), out = P .slot_to_tid (sorted_p )))
3623+ P .emit (
3624+ SortNode (x = P .slot_to_tid (sorted_p ), out = P .slot_to_tid (sorted_p ), axis = - 1 )
3625+ )
3626+ # sorted_p is read three times below (cumsum, prefix subtract, kept where),
3627+ # so stop reusing it here.
3628+ P .emit (NegNode (x = P .slot_to_tid (sorted_p ), out = P .slot_to_tid (sorted_p )))
36303629 _ , cum = P .make_tmp_slot ()
36313630 P .emit (CumsumNode (x = P .slot_to_tid (sorted_p ), out = P .slot_to_tid (cum ), axis = - 1 ))
3632- _ , prefix = P . make_tmp_slot ()
3631+ # prefix = cum - sorted_p; cum is single-use, reuse it in place.
36333632 P .emit (
36343633 SubtractNode (
36353634 a = P .slot_to_tid (cum ),
36363635 b = P .slot_to_tid (sorted_p ),
3637- out = P .slot_to_tid (prefix ),
3636+ out = P .slot_to_tid (cum ),
36383637 )
36393638 )
3639+ prefix = cum
36403640 # remove sorted tokens whose prefix mass already exceeds top_p (top-1: 0)
36413641 _ , remove = P .make_tmp_slot ()
36423642 P .emit (
@@ -3675,6 +3675,7 @@ def emit_sample():
36753675 )
36763676 )
36773677 neg_inf = emit_lifted_constant (P , float ("-inf" ), torch .float32 )
3678+ # masked = where(drop, -inf, scaled); then add gumbel noise in place.
36783679 _ , masked = P .make_tmp_slot ()
36793680 P .emit (
36803681 WhereNode (
@@ -3684,16 +3685,14 @@ def emit_sample():
36843685 out = P .slot_to_tid (masked ),
36853686 )
36863687 )
3687-
3688- _ , noisy = P .make_tmp_slot ()
36893688 P .emit (
36903689 AddNode (
3691- a = P .slot_to_tid (masked ), b = P .slot_to_tid (g ), out = P .slot_to_tid (noisy )
3690+ a = P .slot_to_tid (masked ), b = P .slot_to_tid (g ), out = P .slot_to_tid (masked )
36923691 )
36933692 )
36943693 P .emit (
36953694 ArgmaxNode (
3696- x = P .slot_to_tid (noisy ), out = P .slot_to_tid (out ), axis = - 1 , keepdims = False
3695+ x = P .slot_to_tid (masked ), out = P .slot_to_tid (out ), axis = - 1 , keepdims = False
36973696 )
36983697 )
36993698
0 commit comments