Skip to content

Commit b919db7

Browse files
authored
Reuse buffers in op handlers (#20524)
This PR makes the MLX backend's compound op emitters reuse buffers in place so MLX can donate them, rewriting the batch_norm normalize chain and the sample gumbel/top-p chain to thread results through a small set of temp slots (out==in) instead of allocating a fresh temp per step, while keeping separate slots for multi-use values. An audit of the remaining emitters/handlers (conv, pooling, gated-delta-rule, SDPA, quantized/gguf linear, and the C++ runtime handlers) confirmed they already reuse buffers or are inherently non-donating (views/fused kernels), so no further changes were needed.
1 parent d19320d commit b919db7

1 file changed

Lines changed: 43 additions & 44 deletions

File tree

backends/mlx/ops.py

Lines changed: 43 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3138,34 +3138,34 @@ def reshape_for_broadcast(slot, name_suffix):
31383138
)
31393139
)
31403140

3141-
# Step 3: inv_std = rsqrt(var_eps)
3142-
_, tmp_inv_std = P.make_tmp_slot()
3143-
P.emit(RsqrtNode(x=P.slot_to_tid(tmp_var_eps), out=P.slot_to_tid(tmp_inv_std)))
3141+
# Step 3: inv_std = rsqrt(var_eps), written in place so MLX can donate the
3142+
# var_eps buffer (unary, same shape/dtype).
3143+
P.emit(RsqrtNode(x=P.slot_to_tid(tmp_var_eps), out=P.slot_to_tid(tmp_var_eps)))
3144+
tmp_inv_std = tmp_var_eps
31443145

3145-
# Step 4: x_normalized = x_centered * inv_std
3146-
_, tmp_normalized = P.make_tmp_slot()
3146+
# Step 4: x_normalized = x_centered * inv_std, written in place into the
3147+
# full-size x_centered buffer (the donatable operand; inv_std broadcasts).
31473148
P.emit(
31483149
MultiplyNode(
31493150
a=P.slot_to_tid(tmp_centered),
31503151
b=P.slot_to_tid(tmp_inv_std),
3151-
out=P.slot_to_tid(tmp_normalized),
3152+
out=P.slot_to_tid(tmp_centered),
31523153
)
31533154
)
3155+
tmp_normalized = tmp_centered
31543156

31553157
# Step 5: x_scaled = x_normalized * weight (skip if weight is None, i.e. affine=False)
31563158
if weight is not None:
31573159
weight_reshaped = reshape_for_broadcast(weight, "weight")
3158-
_, tmp_scaled = P.make_tmp_slot()
3160+
# In place into the full-size x_normalized buffer (weight broadcasts).
31593161
P.emit(
31603162
MultiplyNode(
31613163
a=P.slot_to_tid(tmp_normalized),
31623164
b=P.slot_to_tid(weight_reshaped),
3163-
out=P.slot_to_tid(tmp_scaled),
3165+
out=P.slot_to_tid(tmp_normalized),
31643166
)
31653167
)
3166-
current_result = tmp_scaled
3167-
else:
3168-
current_result = tmp_normalized
3168+
current_result = tmp_normalized
31693169

31703170
# Step 6: out = current_result + bias (skip if bias is None, i.e. affine=False)
31713171
if bias is not None:
@@ -3558,47 +3558,39 @@ def emit_sample():
35583558
seed_field = P.slot_to_vid(seed_val)
35593559

35603560
# uniform u in [0, 1): bits/uint32_max, clamped just below 1 (random.cpp:95)
3561-
_, bits = P.make_tmp_slot()
3561+
_, u = P.make_tmp_slot()
35623562
P.emit(
3563-
RandomBitsNode(
3564-
out=P.slot_to_tid(bits), shape=shape, width=4, seed=seed_field
3565-
)
3563+
RandomBitsNode(out=P.slot_to_tid(u), shape=shape, width=4, seed=seed_field)
35663564
)
3567-
_, bits_f = P.make_tmp_slot()
3565+
# u32 -> f32 in place (same itemsize, donatable), then divide/clamp in place.
35683566
P.emit(
35693567
AsTypeNode(
3570-
x=P.slot_to_tid(bits),
3571-
out=P.slot_to_tid(bits_f),
3568+
x=P.slot_to_tid(u),
3569+
out=P.slot_to_tid(u),
35723570
scalar_type=torch_dtype_to_scalar_type(torch.float32),
35733571
)
35743572
)
35753573
umax = emit_lifted_constant(P, 4294967295.0, torch.float32)
3576-
_, div0 = P.make_tmp_slot()
35773574
P.emit(
3578-
DivideNode(
3579-
a=P.slot_to_tid(bits_f), b=P.slot_to_tid(umax), out=P.slot_to_tid(div0)
3580-
)
3575+
DivideNode(a=P.slot_to_tid(u), b=P.slot_to_tid(umax), out=P.slot_to_tid(u))
35813576
)
35823577
prev1 = emit_lifted_constant(
35833578
P,
35843579
float(torch.nextafter(torch.tensor(1.0), torch.tensor(0.0))),
35853580
torch.float32,
35863581
)
3587-
_, clamp = P.make_tmp_slot()
35883582
P.emit(
35893583
MinimumNode(
3590-
a=P.slot_to_tid(div0), b=P.slot_to_tid(prev1), out=P.slot_to_tid(clamp)
3584+
a=P.slot_to_tid(u), b=P.slot_to_tid(prev1), out=P.slot_to_tid(u)
35913585
)
35923586
)
35933587
# gumbel g = -log(-log(u)); whole chain stays fp32 (bf16 mis-ranks ties; clamp->1.0->+inf).
3594-
_, l1 = P.make_tmp_slot()
3595-
P.emit(LogNode(x=P.slot_to_tid(clamp), out=P.slot_to_tid(l1)))
3596-
_, g1 = P.make_tmp_slot()
3597-
P.emit(NegNode(x=P.slot_to_tid(l1), out=P.slot_to_tid(g1)))
3598-
_, l2 = P.make_tmp_slot()
3599-
P.emit(LogNode(x=P.slot_to_tid(g1), out=P.slot_to_tid(l2)))
3588+
# All links are single-use unary ops, so reuse one buffer in place.
36003589
_, g = P.make_tmp_slot()
3601-
P.emit(NegNode(x=P.slot_to_tid(l2), out=P.slot_to_tid(g)))
3590+
P.emit(LogNode(x=P.slot_to_tid(u), out=P.slot_to_tid(g)))
3591+
P.emit(NegNode(x=P.slot_to_tid(g), out=P.slot_to_tid(g)))
3592+
P.emit(LogNode(x=P.slot_to_tid(g), out=P.slot_to_tid(g)))
3593+
P.emit(NegNode(x=P.slot_to_tid(g), out=P.slot_to_tid(g)))
36023594

36033595
# sample: argmax(logits / temperature + g) over the vocab axis, in float32
36043596
_, logits_f = P.make_tmp_slot()
@@ -3609,34 +3601,42 @@ def emit_sample():
36093601
scalar_type=torch_dtype_to_scalar_type(torch.float32),
36103602
)
36113603
)
3612-
_, scaled = P.make_tmp_slot()
3604+
# logits_f is single-use; divide in place. The result (scaled) is read
3605+
# twice (softmax below and the final where), so this buffer must live
3606+
# until then.
36133607
P.emit(
36143608
DivideNode(
36153609
a=P.slot_to_tid(logits_f),
36163610
b=P.slot_to_tid(temperature),
3617-
out=P.slot_to_tid(scaled),
3611+
out=P.slot_to_tid(logits_f),
36183612
)
36193613
)
3614+
scaled = logits_f
36203615

36213616
# top-p nucleus mask; SortNode is ascending-only, so sort -probs for descending.
3617+
# probs is read twice (neg_p below and the drop comparison), keep separate.
36223618
_, probs = P.make_tmp_slot()
36233619
P.emit(SoftmaxNode(x=P.slot_to_tid(scaled), out=P.slot_to_tid(probs), axis=-1))
3624-
_, neg_p = P.make_tmp_slot()
3625-
P.emit(NegNode(x=P.slot_to_tid(probs), out=P.slot_to_tid(neg_p)))
3626-
_, sorted_neg = P.make_tmp_slot()
3627-
P.emit(SortNode(x=P.slot_to_tid(neg_p), out=P.slot_to_tid(sorted_neg), axis=-1))
3620+
# neg_p -> sort -> neg are single-use; thread one buffer.
36283621
_, sorted_p = P.make_tmp_slot()
3629-
P.emit(NegNode(x=P.slot_to_tid(sorted_neg), out=P.slot_to_tid(sorted_p)))
3622+
P.emit(NegNode(x=P.slot_to_tid(probs), out=P.slot_to_tid(sorted_p)))
3623+
P.emit(
3624+
SortNode(x=P.slot_to_tid(sorted_p), out=P.slot_to_tid(sorted_p), axis=-1)
3625+
)
3626+
# sorted_p is read three times below (cumsum, prefix subtract, kept where),
3627+
# so stop reusing it here.
3628+
P.emit(NegNode(x=P.slot_to_tid(sorted_p), out=P.slot_to_tid(sorted_p)))
36303629
_, cum = P.make_tmp_slot()
36313630
P.emit(CumsumNode(x=P.slot_to_tid(sorted_p), out=P.slot_to_tid(cum), axis=-1))
3632-
_, prefix = P.make_tmp_slot()
3631+
# prefix = cum - sorted_p; cum is single-use, reuse it in place.
36333632
P.emit(
36343633
SubtractNode(
36353634
a=P.slot_to_tid(cum),
36363635
b=P.slot_to_tid(sorted_p),
3637-
out=P.slot_to_tid(prefix),
3636+
out=P.slot_to_tid(cum),
36383637
)
36393638
)
3639+
prefix = cum
36403640
# remove sorted tokens whose prefix mass already exceeds top_p (top-1: 0)
36413641
_, remove = P.make_tmp_slot()
36423642
P.emit(
@@ -3675,6 +3675,7 @@ def emit_sample():
36753675
)
36763676
)
36773677
neg_inf = emit_lifted_constant(P, float("-inf"), torch.float32)
3678+
# masked = where(drop, -inf, scaled); then add gumbel noise in place.
36783679
_, masked = P.make_tmp_slot()
36793680
P.emit(
36803681
WhereNode(
@@ -3684,16 +3685,14 @@ def emit_sample():
36843685
out=P.slot_to_tid(masked),
36853686
)
36863687
)
3687-
3688-
_, noisy = P.make_tmp_slot()
36893688
P.emit(
36903689
AddNode(
3691-
a=P.slot_to_tid(masked), b=P.slot_to_tid(g), out=P.slot_to_tid(noisy)
3690+
a=P.slot_to_tid(masked), b=P.slot_to_tid(g), out=P.slot_to_tid(masked)
36923691
)
36933692
)
36943693
P.emit(
36953694
ArgmaxNode(
3696-
x=P.slot_to_tid(noisy), out=P.slot_to_tid(out), axis=-1, keepdims=False
3695+
x=P.slot_to_tid(masked), out=P.slot_to_tid(out), axis=-1, keepdims=False
36973696
)
36983697
)
36993698

0 commit comments

Comments
 (0)