Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .githash
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ff81537a0c8e23806869eef5c28c235b0dc3fbbe
dec047f1bd1c8287513c6c437f946982e516ccd4
5 changes: 5 additions & 0 deletions docs/src/perf.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ Every intermediate `NDArray` (from a slice, broadcast, or function call) allocat
`@analyze_lifetimes` performs a **static last-use analysis** at macro-expansion time and inserts eager `maybe_insert_delete` calls immediately after each temporary's final use. Freed buffers are returned to cuNumeric's pool and recycled by the next same-sized allocation, so no new buffer has to be allocated.

```julia
T = Float32
A = cuNumeric.ones(T, (N, N))
B = cuNumeric.ones(T, (N, N))
C = cuNumeric.zeros(T, (N, N))

@analyze_lifetimes begin
result = A[1:end, :] .+ B[1:end, :]
C .= result .* 2.0
Expand Down
Binary file added gray-scott.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
114 changes: 79 additions & 35 deletions src/scoping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,18 @@ function insert_finalizers(exprs::Vector, assigned_vars::Set{Symbol})
defs = Dict{Symbol,Int}()
alias_map = Dict{Symbol,Symbol}()

# Collect all statements, flattening blocks and skipping LineNumberNodes
stmts = Any[]
for expr in exprs
append!(stmts, expr.args)
if expr isa LineNumberNode
continue
elseif expr isa Expr && expr.head == :block
for arg in expr.args
arg isa LineNumberNode || push!(stmts, arg)
end
else
push!(stmts, expr)
end
end

# Pass 1: collect definitions and uses
Expand Down Expand Up @@ -91,27 +100,45 @@ function insert_finalizers(exprs::Vector, assigned_vars::Set{Symbol})

# Pass 2: insert finalizers
out = Any[]
n = length(stmts)
for (i, stmt) in enumerate(stmts)
push!(out, stmt)
stmt isa Expr || continue
stmt.head == :line && continue

# detect aliasing: v = w means don't finalize w
skip_finalize = Set{Symbol}()
if stmt.head == :(=)
if stmt isa Expr && stmt.head == :(=)
lhs, rhs = stmt.args
# a = tmp1
# tmp1 will be added to skip_finalize
# a[:,:] = tmp1
# this does a copy, so we want to finalize tmp1
if lhs isa Symbol && rhs isa Symbol
push!(skip_finalize, rhs)
end
end

for (v, lasti) in last_use
if lasti == i && v ∈ assigned_vars && !(v ∈ skip_finalize)
push!(out, :(cuNumeric.maybe_insert_delete($v)))
if i == n
# Capture result of the last statement
res_var = Symbol(:res, counter[])
counter[] += 1
push!(out, :($res_var = $stmt))

# Insert finalizers for the last statement
for (v, lasti) in last_use
if lasti == i && v ∈ assigned_vars && !(v ∈ skip_finalize)
# Do not delete if the result of the block is exactly this variable
# or if it's an assignment to this variable.
is_result = (stmt === v)
if stmt isa Expr && stmt.head == :(=) && stmt.args[1] === v
is_result = true
end
if !is_result
push!(out, :(cuNumeric.maybe_insert_delete($v)))
end
end
end
# Return the captured result
push!(out, res_var)
else
push!(out, stmt)
for (v, lasti) in last_use
if lasti == i && v ∈ assigned_vars && !(v ∈ skip_finalize)
push!(out, :(cuNumeric.maybe_insert_delete($v)))
end
end
end
end
Expand All @@ -135,25 +162,10 @@ function insert_finalizers(block::Expr, assigned_vars::Set{Symbol})
end

# Rewrite `block` with `find_ndarray_assignments`, pass the rewritten code to
# `insert_finalizers`, and return the resulting expression.
# NOTE(review): this text was captured from a rendered diff without +/- markers. The
# per-statement loop and the whole-block pass below look like the old and new sides
# of one change rather than code meant to coexist -- confirm against src/scoping.jl.
function process_ndarray_scope(block)
# Normalize block to list of statements
stmts = block isa Expr && block.head == :block ? block.args : [block]

# Otherwise, process and cache
# Filled in by `find_ndarray_assignments`; read by `insert_finalizers`.
assigned_vars = Set{Symbol}()
body = Any[]

# Per-statement pass: each statement is rewritten and finalized on its own.
for stmt in stmts
# Rebinds `stmts` inside the loop body; the iteration itself keeps walking the
# collection that was evaluated when the loop started.
stmts = find_ndarray_assignments(stmt, assigned_vars)
new_stmts = insert_finalizers(stmts, assigned_vars)
push!(body, new_stmts)
end

# println(body)

# This `result` is overwritten by the whole-block pass below before it is returned.
result = quote
$(Expr(:block, body...))
end

# Process the entire block at once so lifetimes are tracked across statements
rewritten = find_ndarray_assignments(block, assigned_vars)
result = insert_finalizers(rewritten, assigned_vars)
# `counter` is defined outside this view; presumably the shared counter used to
# number generated names -- TODO confirm. It is reset after every call.
counter[] = 0
return result
end
Expand Down Expand Up @@ -184,7 +196,28 @@ function find_ndarray_assignments(ex, assigned_vars::Set{Symbol})
push!(local_assigned, lhs)
end
new_rhs, temps = rewrite(rhs)
return Expr(:block, temps..., :($lhs = $new_rhs)), []
return :($lhs = $new_rhs), temps
end

# --- broadcasted assignment: preserve fusion ---
if e.head == :(.=)
lhs, rhs = e.args
new_lhs, lhs_temps = rewrite(lhs)
# Do not hoist the top-level call of the RHS to preserve fusion
if rhs isa Expr && rhs.head == :call
op = rhs.args[1]
new_rhs_args, rhs_temps = Any[], Expr[]
for arg in rhs.args[2:end]
new_arg, t = rewrite(arg)
push!(new_rhs_args, new_arg)
append!(rhs_temps, t)
end
new_rhs = Expr(:call, op, new_rhs_args...)
return Expr(:(.=), new_lhs, new_rhs), vcat(lhs_temps, rhs_temps)
else
new_rhs, rhs_temps = rewrite(rhs)
return Expr(:(.=), new_lhs, new_rhs), vcat(lhs_temps, rhs_temps)
end
end

# --- array slice reference ---
Expand Down Expand Up @@ -214,15 +247,26 @@ function find_ndarray_assignments(ex, assigned_vars::Set{Symbol})

# --- fallback for other Expr types ---
new_args, hoisted = Any[], Expr[]
is_block = e.head == :block || e.head == :begin
for arg in e.args
new_arg, temps = rewrite(arg)
push!(new_args, new_arg)
append!(hoisted, temps)
if is_block && !(arg isa LineNumberNode)
append!(new_args, temps)
push!(new_args, new_arg)
else
push!(new_args, new_arg)
append!(hoisted, temps)
end
end
return Expr(e.head, new_args...), hoisted
end

new_ex, temps = rewrite(ex)
union!(assigned_vars, local_assigned)
return Expr(:block, temps..., new_ex)

if new_ex isa Expr && new_ex.head == :block
return Expr(:block, temps..., new_ex.args...)
else
return Expr(:block, temps..., new_ex)
end
end
25 changes: 25 additions & 0 deletions test/tests/scoping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,28 @@ const SLICE_OPS = Dict(
),
)

# Regression tests for `@analyze_lifetimes`, run with element type `T` on N-by-N inputs.
# Covers two usage forms: a block ending in an in-place broadcast assignment, and the
# macro applied to a single expression whose value is bound by the caller.
function test_scoping_regressions(T, N)
# Inputs built with `ones`, destination built with `zeros`.
A = cuNumeric.ones(T, (N, N))
B = cuNumeric.ones(T, (N, N))
C = cuNumeric.zeros(T, (N, N))

@testset "In-place assignment" begin
@analyze_lifetimes begin
# `result` is a temporary created inside the macro block and consumed by the
# `.=` on the next line; `C` is defined outside the block and checked afterwards.
result = A[1:end, :] .+ B[1:end, :]
C .= result .* T(2.0)
end
# Test values: (1+1) * 2 = 4
@test all(Array(C) .== T(4.0))
end

@testset "Macro as RHS" begin
# Test values: (1+1)^2 = 4
# The macro's value is assigned to `res`, so the final result must still be usable
# after the macro's expansion has run.
res = @analyze_lifetimes (A .+ B) .^ 2
@test res isa cuNumeric.NDArray
@test all(Array(res) .== T(4.0))
end
end

function run_all_ops(FT, N)
results = Dict()

Expand All @@ -101,5 +123,8 @@ function run_all_ops(FT, N)
results[name] = (c_base, c_scoped)
end

# Regression tests
test_scoping_regressions(FT, N)

return results
end
Loading