Skip to content

Commit 5936cbe

Browse files
committed
Support @check_allocs at callsites
1 parent 48a8c34 commit 5936cbe

File tree

1 file changed

+51
-15
lines changed

1 file changed

+51
-15
lines changed

Diff for: src/macro.jl

+51-15
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,22 @@ end
2525

2626
"""
2727
@check_allocs ignore_throw=true (function def)
28+
@check_allocs ignore_throw=true func(...)
2829
2930
Wraps the provided function definition so that all calls to it will be automatically
3031
checked for allocations.
3132
3233
If the check fails, an `AllocCheckFailure` exception is thrown containing the detailed
3334
failures, including the backtrace for each defect.
3435
35-
Note: All calls to the wrapped function are effectively a dynamic dispatch, which
36-
means they are type-unstable and may allocate memory at function _entry_. `@check_allocs`
37-
only guarantees the absence of allocations after the function has started running.
36+
`@check_allocs` can also be applied to a function call, which operates by creating
37+
an anonymous function that is passed to `@check_allocs` and then immediately calling
38+
the wrapped result.
39+
40+
!!! note
41+
All calls to the wrapped function are effectively a dynamic dispatch, which
42+
means they are type-unstable and may allocate memory at function _entry_. `@check_allocs`
43+
only guarantees the absence of allocations after the function has started running.
3844
3945
# Example
4046
```jldoctest
@@ -45,23 +51,27 @@ julia> multiply(1.5, 3.5) # no allocations for Float64
4551
5.25
4652
4753
julia> multiply(rand(3,3), rand(3,3)) # matmul needs to allocate the result
48-
ERROR: @check_alloc function contains 1 allocations.
49-
54+
ERROR: @check_alloc function contains 1 allocations (1 allocations / 0 dynamic dispatches).
5055
Stacktrace:
5156
[1] macro expansion
52-
@ ~/repos/AllocCheck/src/macro.jl:134 [inlined]
57+
@ ~/.julia/dev/AllocCheck/src/macro.jl:157 [inlined]
5358
[2] multiply(x::Matrix{Float64}, y::Matrix{Float64})
54-
@ Main ./REPL[2]:133
59+
@ Main ./REPL[2]:156
5560
[3] top-level scope
56-
@ REPL[5]:1
61+
@ REPL[4]:1
62+
63+
julia> @check_allocs 1.5 * 3.5 # check a call
64+
5.25
5765
```
5866
"""
5967
macro check_allocs(ex...)
6068
kws, body = extract_keywords(ex)
6169
if _is_func_def(body)
62-
return _check_allocs_macro(body, __module__, __source__; kws...)
70+
return _check_allocs_defun(body, __module__, __source__; kws...)
71+
elseif Meta.isexpr(body, :call)
72+
return _check_allocs_call(body, __module__, __source__; kws...)
6373
else
64-
error("@check_allocs used on something other than a function definition")
74+
error("@check_allocs used on anything other than a function definition or call")
6575
end
6676
end
6777

@@ -117,13 +127,20 @@ function forward_args!(func_def)
117127
args, kwargs
118128
end
119129

120-
function _check_allocs_macro(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true)
130+
function _check_allocs_defun(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true)
131+
(; original_fn, f_sym, wrapper_fn) = _check_allocs_wrap_fn(ex, mod, source; ignore_throw)
132+
quote
133+
local $f_sym = $(esc(original_fn))
134+
$wrapper_fn
135+
end
136+
end
121137

138+
function _check_allocs_wrap_fn(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true)
122139
# Transform original function to a renamed version with flattened args
123140
def = splitdef(deepcopy(ex))
124141
normalize_args!(def)
125142
original_fn = combinedef(def)
126-
f_sym = haskey(def, :name) ? gensym(def[:name]) : gensym()
143+
f_sym = haskey(def, :name) ? gensym(def[:name]) : gensym("fn_alias")
127144

128145
# Next, create a wrapper function that will compile the original function on-the-fly.
129146
def = splitdef(ex)
@@ -149,8 +166,27 @@ function _check_allocs_macro(ex::Expr, mod::Module, source::LineNumberNode; igno
149166
def[:body].args[1] = source
150167

151168
wrapper_fn = combinedef(def)
152-
return quote
153-
local $f_sym = $(esc(original_fn))
154-
$(wrapper_fn)
169+
170+
(; original_fn, f_sym, wrapper_fn)
171+
end
172+
173+
function _check_allocs_call(ex::Expr, mod::Module, source::LineNumberNode; ignore_throw=true)
174+
fn = first(ex.args)
175+
args = ex.args[2:end]
176+
args_template = if !isempty(args) && Meta.isexpr(first(args), :parameters)
177+
kwargs = Expr(:parameters, map(a -> if Meta.isexpr(a, :kw) first(a.args) else a end::Symbol, popfirst!(args).args)...)
178+
[kwargs, map(_ -> gensym("arg"), 1:length(args))...]
179+
else
180+
[map(_ -> gensym("arg"), 1:length(args))...]
181+
end
182+
passthrough_defun = Expr(:function, Expr(:tuple, args_template...), Expr(:call, fn, args_template...))
183+
(; f_sym, wrapper_fn) = _check_allocs_wrap_fn(passthrough_defun, mod, source; ignore_throw)
184+
passthrough_defun_esc = Expr(:function, Expr(:tuple, map(esc, args_template)...), Expr(:call, fn, map(esc, args_template)...))
185+
af_sym = gensym("alloccheck_fn")
186+
quote
187+
let $f_sym = $passthrough_defun_esc
188+
$af_sym = $wrapper_fn
189+
$(Expr(:call, af_sym, map(esc, args)...))
190+
end
155191
end
156192
end

0 commit comments

Comments
 (0)