Skip to content

Commit 7f40c18

Browse files
fix: fix inference of substitute
1 parent c88cf52 commit 7f40c18

File tree

4 files changed

+43
-18
lines changed

4 files changed

+43
-18
lines changed

benchmark/benchmarks.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,21 @@ let r = @rule(~x => ~x), rs = RuleSet([r]),
5858

5959
# we use `fold = false` since otherwise it dynamic dispatches to `sin`/`cos` whenever
6060
# both arguments in the contained addition are substituted.
61-
overhead["substitute"]["a"] = @benchmarkable substitute(subs_expr, $(Dict(a=>1)); fold = false) setup=begin
61+
62+
fold = @static if hasmethod(SymbolicUtils.Substituter{true}, Tuple{Dict, Function})
63+
Val{false}()
64+
else
65+
false
66+
end
67+
overhead["substitute"]["a"] = @static @benchmarkable substitute(subs_expr, $(Dict(a=>1)); fold) setup=begin
6268
subs_expr = (sin(a+b) + cos(b+c)) * (sin(b+c) + cos(c+a)) * (sin(c+a) + cos(a+b))
6369
end
6470

65-
overhead["substitute"]["a,b"] = @benchmarkable substitute(subs_expr, $(Dict(a=>1, b=>2)); fold = false) setup=begin
71+
overhead["substitute"]["a,b"] = @benchmarkable substitute(subs_expr, $(Dict(a=>1, b=>2)); fold) setup=begin
6672
subs_expr = (sin(a+b) + cos(b+c)) * (sin(b+c) + cos(c+a)) * (sin(c+a) + cos(a+b))
6773
end
6874

69-
overhead["substitute"]["a,b,c"] = @benchmarkable substitute(subs_expr, $(Dict(a=>1, b=>2, c=>3)); fold = false) setup=begin
75+
overhead["substitute"]["a,b,c"] = @benchmarkable substitute(subs_expr, $(Dict(a=>1, b=>2, c=>3)); fold) setup=begin
7076
subs_expr = (sin(a+b) + cos(b+c)) * (sin(b+c) + cos(c+a)) * (sin(c+a) + cos(a+b))
7177
end
7278

src/SymbolicUtils.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ include("substitute.jl")
155155
include("code.jl")
156156

157157
PrecompileTools.@setup_workload begin
158+
fold1 = Val{false}()
159+
fold2 = Val{true}()
158160
PrecompileTools.@compile_workload begin
159161
@syms x y f(t) q[1:5]
160162
Sym{SymReal}(:a; type = Real, shape = ShapeVecT())
@@ -173,8 +175,15 @@ PrecompileTools.@setup_workload begin
173175
show(devnull, x ^ 2 + y * x + y / 3x)
174176
expand((x + y) ^ 2)
175177
simplify(x ^ (1//2) + (sin(x) ^ 2 + cos(x) ^ 2) + 2(x + y) - x - y)
176-
substitute(x + 2y + sin(x), Dict(x => y); fold = false)
177-
substitute(x + 2y + sin(x), Dict(x => 1); fold = true)
178+
ex = x + 2y + sin(x)
179+
rules1 = Dict(x => y)
180+
rules2 = Dict(x => 1)
181+
substitute(ex, rules1)
182+
substitute(ex, rules1; fold = fold1)
183+
substitute(ex, rules2; fold = fold1)
184+
substitute(ex, rules2)
185+
substitute(ex, rules1; fold = fold2)
186+
substitute(ex, rules2; fold = fold2)
178187
q[1]
179188
q'q
180189
end

src/substitute.jl

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,22 @@ struct Substituter{Fold, D <: AbstractDict, F}
33
filter::F
44
end
55

6+
@inline function Substituter{Fold}(d::AbstractDict, filter::F) where {Fold, F}
7+
Substituter{Fold, typeof(d), F}(d, filter)
8+
end
9+
@inline function Substituter{Fold}(d::Pair, filter::F) where {Fold, F}
10+
Substituter{Fold}(Dict(d), filter)
11+
end
12+
@inline function Substituter{Fold}(d::AbstractArray{<:Pair}, filter::F) where {Fold, F}
13+
Substituter{Fold}(Dict(d), filter)
14+
end
15+
616
function (s::Substituter)(ex)
717
return get(s.dict, ex, ex)
818
end
919

1020
function (s::Substituter)(ex::AbstractArray)
11-
map(s, ex)
21+
[s(x) for x in ex]
1222
end
1323

1424
function (s::Substituter)(ex::SparseMatrixCSC)
@@ -90,22 +100,22 @@ end
90100
end
91101

92102
"""
93-
substitute(expr, dict; fold=true)
103+
substitute(expr, dict; fold=Val(true))
94104
95105
substitute any subexpression that matches a key in `dict` with
96-
the corresponding value. If `fold=false`,
106+
the corresponding value. If `fold=Val(false)`,
97107
expressions which can be evaluated won't be evaluated.
98108
99109
```julia
100-
julia> substitute(1+sqrt(y), Dict(y => 2), fold=true)
110+
julia> substitute(1+sqrt(y), Dict(y => 2), fold=Val(true))
101111
2.414213562373095
102-
julia> substitute(1+sqrt(y), Dict(y => 2), fold=false)
112+
julia> substitute(1+sqrt(y), Dict(y => 2), fold=Val(false))
103113
1 + sqrt(2)
104114
```
105115
"""
106-
@inline function substitute(expr, dict; fold=true, filterer=default_substitute_filter)
107-
isempty(dict) && !fold && return expr
108-
return Substituter{fold, typeof(dict), typeof(filterer)}(dict, filterer)(expr)
116+
@inline function substitute(expr, dict; fold::Val{Fold}=Val{true}(), filterer=default_substitute_filter) where {Fold}
117+
isempty(dict) && !Fold && return expr
118+
return Substituter{Fold}(dict, filterer)(expr)
109119
end
110120

111121
"""
@@ -265,7 +275,7 @@ function _reduce_eliminated_idxs(expr::BasicSymbolic{T}, output_idx::OutIdxT{T},
265275
for (idx, ii) in zip(iidxs, collapsed)
266276
subrules[ii] = idx
267277
end
268-
return substitute(new_expr, subrules; fold = false)
278+
return substitute(new_expr, subrules; fold = Val{false}())
269279
end
270280
end
271281
@cache function reduce_eliminated_idxs_1(expr::BasicSymbolic{SymReal}, output_idx::OutIdxT{SymReal}, ranges::RangesT{SymReal}, reduce)::BasicSymbolic{SymReal}
@@ -335,9 +345,9 @@ function scalarize(x::BasicSymbolic{T}, ::Val{toplevel} = Val{false}()) where {T
335345
subrules[ii] = idxs[i]
336346
end
337347
if toplevel
338-
substitute(new_expr, subrules; fold = true)
348+
substitute(new_expr, subrules; fold = Val{true}())
339349
else
340-
scalarize(substitute(new_expr, subrules; fold = true))
350+
scalarize(substitute(new_expr, subrules; fold = Val{true}()))
341351
end
342352
end
343353
end

src/types.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3155,10 +3155,10 @@ Base.@propagate_inbounds function _getindex(arr::BasicSymbolic{T}, idxs::Union{B
31553155
end
31563156
if isempty(new_output_idx)
31573157
new_expr = reduce_eliminated_idxs(expr, output_idx, ranges, reduce)
3158-
result = substitute(new_expr, subrules; fold = false, filterer = !isarrayop)
3158+
result = substitute(new_expr, subrules; fold = Val{false}(), filterer = !isarrayop)
31593159
return result
31603160
else
3161-
new_expr = substitute(expr, subrules; fold = false, filterer = !isarrayop)
3161+
new_expr = substitute(expr, subrules; fold = Val{false}(), filterer = !isarrayop)
31623162
if term !== nothing
31633163
term = getindex(term, idxs...)
31643164
end

0 commit comments

Comments
 (0)