diff --git a/README.md b/README.md index 2876a31..7e4dc0d 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [ci-img]: https://github.com/JuliaCollections/Memoize.jl/workflows/CI/badge.svg [ci-url]: https://github.com/JuliaCollections/Memoize.jl/actions -Easy memoization for Julia. +Easy method memoization for Julia. ## Usage @@ -23,15 +23,16 @@ julia> x(1) Running 2 -julia> memoize_cache(x) -IdDict{Any,Any} with 1 entry: - (1,) => 2 +julia> memories(x) +1-element Array{Any,1}: + IdDict{Tuple{Any},Any}((1,) => 2) julia> x(1) 2 -julia> empty!(memoize_cache(x)) -IdDict{Any,Any}() +julia> map(empty!, memories(x)) +1-element Array{IdDict{Tuple{Any},Any},1}: + IdDict() julia> x(1) Running @@ -41,22 +42,22 @@ julia> x(1) 2 ``` -By default, Memoize.jl uses an [`IdDict`](https://docs.julialang.org/en/v1/base/collections/#Base.IdDict) as a cache, but it's also possible to specify the type of the cache. If you want to cache vectors based on the values they contain, you probably want this: +By default, Memoize.jl uses an [`IdDict`](https://docs.julialang.org/en/v1/base/collections/#Base.IdDict) as a cache, but it's also possible to specify your own cache that supports the methods `Base.get!` and `Base.empty!`. If you want to cache vectors based on the values they contain, you probably want this: ```julia using Memoize -@memoize Dict function x(a) +@memoize Dict() function x(a) println("Running") a end ``` -You can also specify the full function call for constructing the dictionary. For example, to use LRUCache.jl: +You can also specify the full expression for constructing the cache. The variables `__Key__` and `__Value__` are available to the constructor expression, containing the syntactically determined type bounds on the keys and values used by Memoize.jl. 
For example, to use LRUCache.jl: ```julia using Memoize using LRUCache -@memoize LRU{Tuple{Any,Any},Any}(maxsize=2) function x(a, b) +@memoize LRU{__Key__,__Value__}(maxsize=2) function x(a, b) println("Running") a + b end @@ -86,12 +87,34 @@ julia> x(2,3) 5 ``` -## Notes +Memoize works on *almost* every method declaration in global and local scope, including lambdas and callable objects. When an argument is unnamed (only its type is given), Memoize uses that type as part of the cache key. -Note that the `@memoize` macro treats the type argument differently depending on its syntactical form: in the expression -```julia -@memoize CacheType function x(a, b) - # ... +```julia +struct F{A} + a::A +end +@memoize function (::F{A})(b, ::C) where {A, C} + println("Running") + (A, b, C) end ``` -the expression `CacheType` must be either a non-function-call that evaluates to a type, or a function call that evaluates to an _instance_ of the desired cache type. Either way, the methods `Base.get!` and `Base.empty!` must be defined for the supplied cache type. 
+ +``` +julia> F(1)(1, 1) +Running +(Int64, 1, Int64) + +julia> F(1)(1, 2) +(Int64, 1, Int64) + +julia> F(1)(2, 2) +Running +(Int64, 2, Int64) + +julia> F(2)(2, 2) +(Int64, 2, Int64) + +julia> F(false)(2, 2) +Running +(Bool, 2, Int64) +``` diff --git a/src/Memoize.jl b/src/Memoize.jl index 6406c58..0ba5e0d 100644 --- a/src/Memoize.jl +++ b/src/Memoize.jl @@ -1,89 +1,211 @@ module Memoize -using MacroTools: isexpr, combinedef, namify, splitarg, splitdef -export @memoize, memoize_cache +using MacroTools: isexpr, combinearg, combinedef, namify, splitarg, splitdef, @capture +export @memoize, memories, memory -cache_name(f) = Symbol("##", f, "_memoized_cache") +# `which(sig)` would do this lookup, but that method seems to exist only on Julia >= 1.6, so ask the runtime directly +function _which(tt, world = typemax(UInt)) + meth = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), tt, world) + if meth !== nothing + if meth isa Method + return meth::Method + else + meth = meth.func + return meth::Method + end + end +end + +""" + @memoize [cache] declaration + + Transform any method declaration `declaration` (except for inner constructors) so that calls to the original method are cached by their arguments. When an argument is unnamed, its type is treated as an argument instead. + + `cache` should be an expression which evaluates to a dictionary-like object that supports `get!` and `empty!`, and may depend on the local variables `__Key__` and `__Value__`, which evaluate to syntactically-determined bounds on the required key and value types the cache must support. + If the given cache contains values, it is assumed that they will agree with the values the method returns. Specializing a method will not empty the cache, but overwriting a method will. The caches corresponding to methods can be determined with `memory` or `memories`. +""" macro memoize(args...) 
if length(args) == 1 - dicttype = :(IdDict) + cache_constructor = :(IdDict{__Key__}{__Value__}()) ex = args[1] elseif length(args) == 2 - (dicttype, ex) = args + (cache_constructor, ex) = args else error("Memoize accepts at most two arguments") end - cache_dict = isexpr(dicttype, :call) ? dicttype : :(($dicttype)()) - - def_dict = try + def = try splitdef(ex) catch error("@memoize must be applied to a method definition") end + + function split(arg, iskwarg=false) + arg_name, arg_type, slurp, default = splitarg(arg) + trait = arg_name === nothing + trait && (arg_name = gensym()) + vararg = namify(arg_type) === :Vararg + return ( + arg_name = arg_name, + arg_type = arg_type, + arg_value = arg_name, + slurp = slurp, + vararg = vararg, + default = default, + trait = trait, + iskwarg = iskwarg) + end - # a return type declaration of Any is a No-op because everything is <: Any - rettype = get(def_dict, :rtype, Any) - f = def_dict[:name] - def_dict_unmemoized = copy(def_dict) - def_dict_unmemoized[:name] = u = Symbol("##", f, "_unmemoized") - - args = def_dict[:args] - kws = def_dict[:kwargs] - # Set up arguments for tuple - tup = [splitarg(arg)[1] for arg in vcat(args, kws)] - - # Set up identity arguments to pass to unmemoized function - identargs = map(args) do arg - arg_name, typ, slurp, default = splitarg(arg) - if slurp || namify(typ) === :Vararg - Expr(:..., arg_name) - else - arg_name + combine(arg) = combinearg(arg.arg_name, arg.arg_type, arg.slurp, arg.default) + + pass(arg) = + (arg.slurp || arg.vararg) ? Expr(:..., arg.arg_name) : + arg.iskwarg ? Expr(:kw, arg.arg_name, arg.arg_name) : arg.arg_name + + dispatch(arg) = arg.slurp ? 
:(Vararg{$(arg.arg_type)}) : arg.arg_type + + args = split.(def[:args]) + kwargs = split.(def[:kwargs], true) + def[:args] = combine.(args) + def[:kwargs] = combine.(kwargs) + @gensym inner + inner_def = deepcopy(def) + inner_def[:name] = inner + inner_args = copy(args) + inner_kwargs = copy(kwargs) + pop!(inner_def, :params, nothing) + + @gensym result + + # If this is a method of a callable type or object, the definition returns nothing. + # Thus, we must construct the type of the method on our own. + # We also need to pass the object to the inner function + if haskey(def, :name) + if haskey(def, :params) # Callable type + typ = :($(def[:name]){$(pop!(def, :params)...)}) + inner_args = [split(:(::Type{$typ})), inner_args...] + def[:name] = combine(inner_args[1]) + head = :(Type{$typ}) + elseif @capture(def[:name], obj_::obj_type_ | ::obj_type_) # Callable object + inner_args = [split(def[:name]), inner_args...] + def[:name] = combine(inner_args[1]) + head = obj_type + else # Normal call + head = :(typeof($(def[:name]))) end + else # Anonymous function + head = :(typeof($result)) end - identkws = map(kws) do kw - arg_name, typ, slurp, default = splitarg(kw) - if slurp - Expr(:..., arg_name) - else - Expr(:kw, arg_name, arg_name) - end + inner_def[:args] = combine.(inner_args) + + # Set up arguments for memo key + key_names = map([inner_args; inner_kwargs]) do arg + arg.trait ? arg.arg_type : arg.arg_name + end + key_types = map([inner_args; inner_kwargs]) do arg + arg.trait ? Type : # the key is the declared type itself, which may be a UnionAll or Union, not only a DataType + (arg.vararg || (arg.slurp && !arg.iskwarg)) ? :(Tuple{$(dispatch(arg))}) : # positional slurps/varargs are keyed by the tuple of collected arguments + arg.arg_type end - fcachename = cache_name(f) - mod = __module__ - fcache = isdefined(mod, fcachename) ? 
- getfield(mod, fcachename) : - Core.eval(mod, :(const $fcachename = $cache_dict)) + @gensym cache - body = quote - get!($fcache, ($(tup...),)) do - $u($(identargs...); $(identkws...)) + pass_args = pass.(inner_args) + pass_kwargs = pass.(inner_kwargs) + def[:body] = quote + $(combinedef(inner_def)) + get!($cache, ($(key_names...),)) do + $inner($(pass_args...); $(pass_kwargs...)) end end - if length(kws) == 0 - def_dict[:body] = quote - $(body)::Core.Compiler.return_type($u, typeof(($(identargs...),))) + # A return type declaration of Any is a No-op because everything is <: Any + return_type = get(def, :rtype, Any) + + if length(kwargs) == 0 + def[:body] = quote + $(def[:body])::Core.Compiler.widenconst(Core.Compiler.return_type($inner, typeof(($(pass_args...),)))) end - else - def_dict[:body] = body end - esc(quote - $(combinedef(def_dict_unmemoized)) - empty!($fcache) - Base.@__doc__ $(combinedef(def_dict)) - end) + @gensym world + @gensym old_meth + @gensym meth + @gensym brain + @gensym old_brain + + sig = :(Tuple{$head, $(dispatch.(args)...)} where {$(def[:whereparams]...)}) + + return esc(quote + # The `local` qualifier will make this performant even in the global scope. + local $cache = begin + local __Key__ = (Tuple{$(key_types...)} where {$(def[:whereparams]...)}) + local __Value__ = ($return_type where {$(def[:whereparams]...)}) + $cache_constructor + end + + local $world = Base.get_world_counter() + + local $result = Base.@__doc__($(combinedef(def))) + + local $brain = if isdefined($__module__, :__Memoize_brain__) + getfield($__module__, :__Memoize_brain__) # no named temporary: this code is escaped, so it would leak into the caller's scope + else + global __Memoize_brain__ = Dict() + end + + # If overwriting a method, empty the old cache. 
+ # Notice that methods are hashed by their stored signature + local $old_meth = $_which($sig, $world) + if $old_meth !== nothing && $old_meth.sig == $sig + if isdefined($old_meth.module, :__Memoize_brain__) + local $old_brain = getfield($old_meth.module, :__Memoize_brain__) # `local` keeps the gensym'd name out of the caller's globals + empty!(pop!($old_brain, $old_meth.sig, [])) + end + end + + # Store the cache so that it can be emptied later + local $meth = $_which($sig) + @assert $meth !== nothing + $brain[$meth.sig] = $cache + $result + end) end -function memoize_cache(f::Function) - # This will fail in certain circumstances (eg. @memoize Base.sin(::MyNumberType) = ...) but I don't think there's - # a clean answer here, because we can already have multiple caches for certain functions, if the methods are - # defined in different modules. - getproperty(parentmodule(f), cache_name(f)) +""" + memories(f, [types], [module]) + + Return an array of memoized method caches for the function `f`. + + This function takes the same arguments as `methods`. +""" +memories(f, args...) 
= _memories(methods(f, args...)) + +function _memories(ms::Base.MethodList) + memories = [] + for m in ms + if isdefined(m.module, :__Memoize_brain__) + brain = getfield(m.module, :__Memoize_brain__) + memory = get(brain, m.sig, nothing) + if memory !== nothing + push!(memories, memory) + end + end + end + return memories end +""" + memory(m) + + Return the memoized cache for the method `m`, or `nothing` if no cache is registered for it. +""" +function memory(m::Method) + if isdefined(m.module, :__Memoize_brain__) + brain = getfield(m.module, :__Memoize_brain__) + return get(brain, m.sig, nothing) + end end + +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index d9639ed..2d2c5cf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,12 +29,32 @@ end @test simple(6) == 6 @test run == 2 -empty!(memoize_cache(simple)) +map(empty!, memories(simple)) @test simple(6) == 6 @test run == 3 @test simple(6) == 6 @test run == 3 +run = 0 +lambda = @memoize (a) -> begin + global run += 1 + a +end +@test lambda(5) == 5 +@test run == 1 +@test lambda(5) == 5 +@test run == 1 +@test lambda(6) == 6 +@test run == 2 +@test lambda(6) == 6 +@test run == 2 + +map(empty!, memories(lambda)) +@test lambda(6) == 6 +@test run == 3 +@test lambda(6) == 6 +@test run == 3 + run = 0 @memoize function typed(a::Int) global run += 1 @@ -183,7 +203,7 @@ end @test run == 2 run = 0 -@memoize Dict function kw_ellipsis(;a...) +@memoize Dict() function kw_ellipsis(;a...) 
global run += 1 a end @@ -254,6 +274,137 @@ end outer() @test !@isdefined inner +trait_function(a, ::Bool) = (-a,) +run = 0 +@memoize function trait_function(a, ::Int) + global run += 1 + (a,) +end +@test trait_function(1, true) == (-1,) +@test run == 0 +@test trait_function(2, true) == (-2,) +@test run == 0 +@test trait_function(1, 1) == (1,) +@test run == 1 +@test trait_function(1, 2) == (1,) +@test run == 1 +@test trait_function(2, 2) == (2,) +@test run == 2 +@test trait_function(2, 2) == (2,) +@test run == 2 + +run = 0 +@memoize function trait_params(a, ::T) where {T} + global run += 1 + (a, T) +end +@test trait_params(1, true) == (1, Bool) +@test run == 1 +@test trait_params(1, false) == (1, Bool) +@test run == 1 +@test trait_params(2, true) == (2, Bool) +@test run == 2 +@test trait_params(2, false) == (2, Bool) +@test run == 2 +@test trait_params(1, 3) == (1, Int) +@test run == 3 +@test trait_params(1, 4) == (1, Int) +@test run == 3 + +run = 0 +struct callable_object + a +end +@memoize function (o::callable_object)(b) + global run += 1 + (o.a, b) +end +@test callable_object(1)(2) == (1, 2) +@test run == 1 +@test callable_object(1)(2) == (1, 2) +@test run == 1 +@test callable_object(1)(3) == (1, 3) +@test run == 2 +@test callable_object(1)(3) == (1, 3) +@test run == 2 +@test callable_object(2)(3) == (2, 3) +@test run == 3 +@test callable_object(2)(3) == (2, 3) +@test run == 3 + +run = 0 +struct callable_trait_object{T} + a::T +end +@memoize function (::callable_trait_object{T})(b) where {T} + global run += 1 + (T, b) +end +@test callable_trait_object(1)(2) == (Int, 2) +@test run == 1 +@test callable_trait_object(2)(2) == (Int, 2) +@test run == 1 +@test callable_trait_object(false)(2) == (Bool, 2) +@test run == 2 +@test callable_trait_object(true)(3) == (Bool, 3) +@test run == 3 +@test callable_trait_object(1)(3) == (Int, 3) +@test run == 4 +@test callable_trait_object(2)(3) == (Int, 3) +@test run == 4 + +run = 0 +struct callable_type{T} + a::T +end +@memoize 
function callable_type{T}(b) where {T} + global run += 1 + (T, b) +end +@test callable_type{Int}(2) == (Int, 2) +@test run == 1 +@test callable_type{Int}(2) == (Int, 2) +@test run == 1 +@test callable_type{Int}(3) == (Int, 3) +@test run == 2 +@test callable_type{Int}(3) == (Int, 3) +@test run == 2 +@test callable_type{Bool}(3) == (Bool, 3) +@test run == 3 +@test callable_type{Bool}(3) == (Bool, 3) +@test run == 3 + +genrun = 0 +@memoize function genspec(a) + global genrun += 1 + a + 1 +end +specrun = 0 +@test genspec(5) == 6 +@test genrun == 1 +@test specrun == 0 +@memoize function genspec(a::Int) + global specrun += 1 + a + 2 +end +@test genspec(5) == 7 +@test genrun == 1 +@test specrun == 1 +@test genspec(5) == 7 +@test genrun == 1 +@test specrun == 1 +@test genspec(true) == 2 +@test genrun == 2 +@test specrun == 1 +@test invoke(genspec, Tuple{Any}, 5) == 6 +@test genrun == 2 +@test specrun == 1 + +map(empty!, memories(genspec, Tuple{Int})) +@test genspec(5) == 7 +@test genrun == 2 +@test specrun == 2 + @memoize function typeinf(x) x + 1 end @@ -267,6 +418,9 @@ finalized = false x end method_rewrite() +@memoize function method_rewrite(x) end +GC.gc() +@test !finalized @memoize function method_rewrite() end GC.gc() @test finalized @@ -308,7 +462,7 @@ using Memoize const MyDict = Dict run = 0 -@memoize MyDict function custom_dict(a) +@memoize MyDict() function custom_dict(a) global run += 1 a end @@ -328,13 +482,13 @@ end # module using .MemoizeTest using .MemoizeTest: custom_dict -empty!(memoize_cache(custom_dict)) +map(empty!, memories(custom_dict)) @test custom_dict(1) == 1 @test MemoizeTest.run == 3 @test custom_dict(1) == 1 @test MemoizeTest.run == 3 -empty!(memoize_cache(MemoizeTest.custom_dict)) +map(empty!, memories(MemoizeTest.custom_dict)) @test custom_dict(1) == 1 @test MemoizeTest.run == 4 @@ -350,5 +504,21 @@ end @test dict_call("bb") == 2 @test run == 2 @test dict_call("bb") == 2 + +run = 0 +@memoize Dict{__Key__,__Value__}() function 
auto_dict_call(a::String)::Int + global run += 1 + length(a) +end +@test auto_dict_call("a") == 1 +@test run == 1 +@test auto_dict_call("a") == 1 +@test run == 1 +@test auto_dict_call("bb") == 2 +@test run == 2 +@test auto_dict_call("bb") == 2 @test run == 2 +@test memories(auto_dict_call)[1] isa Dict{Tuple{String}, Int} +@memoize non_allocating(x) = x+1 +@test @allocated(non_allocating(10)) == 0