@@ -3,12 +3,22 @@ struct Substituter{Fold, D <: AbstractDict, F}
33 filter:: F
44end
55
6+ @inline function Substituter {Fold} (d:: AbstractDict , filter:: F ) where {Fold, F}
7+ Substituter {Fold, typeof(d), F} (d, filter)
8+ end
9+ @inline function Substituter {Fold} (d:: Pair , filter:: F ) where {Fold, F}
10+ Substituter {Fold} (Dict (d), filter)
11+ end
12+ @inline function Substituter {Fold} (d:: AbstractArray{<:Pair} , filter:: F ) where {Fold, F}
13+ Substituter {Fold} (Dict (d), filter)
14+ end
15+
616function (s:: Substituter )(ex)
717 return get (s. dict, ex, ex)
818end
919
1020function (s:: Substituter )(ex:: AbstractArray )
11- map (s, ex)
21+ [ s (x) for x in ex]
1222end
1323
1424function (s:: Substituter )(ex:: SparseMatrixCSC )
90100end
91101
92102"""
93- substitute(expr, dict; fold=true)
103+ substitute(expr, dict; fold=Val( true) )
94104
95105substitute any subexpression that matches a key in `dict` with
96- the corresponding value. If `fold=false`,
106+ the corresponding value. If `fold=Val( false) `,
97107expressions which can be evaluated won't be evaluated.
98108
99109```julia
100- julia> substitute(1+sqrt(y), Dict(y => 2), fold=true)
110+ julia> substitute(1+sqrt(y), Dict(y => 2), fold=Val( true) )
1011112.414213562373095
102- julia> substitute(1+sqrt(y), Dict(y => 2), fold=false)
112+ julia> substitute(1+sqrt(y), Dict(y => 2), fold=Val( false) )
1031131 + sqrt(2)
104114```
105115"""
106- @inline function substitute (expr, dict; fold= true , filterer= default_substitute_filter)
107- isempty (dict) && ! fold && return expr
108- return Substituter {fold, typeof(dict), typeof(filterer) } (dict, filterer)(expr)
116+ @inline function substitute (expr, dict; fold:: Val{Fold} = Val { true} () , filterer= default_substitute_filter) where {Fold}
117+ isempty (dict) && ! Fold && return expr
118+ return Substituter {Fold } (dict, filterer)(expr)
109119end
110120
111121"""
@@ -265,7 +275,7 @@ function _reduce_eliminated_idxs(expr::BasicSymbolic{T}, output_idx::OutIdxT{T},
265275 for (idx, ii) in zip (iidxs, collapsed)
266276 subrules[ii] = idx
267277 end
268- return substitute (new_expr, subrules; fold = false )
278+ return substitute (new_expr, subrules; fold = Val { false} () )
269279 end
270280end
271281@cache function reduce_eliminated_idxs_1 (expr:: BasicSymbolic{SymReal} , output_idx:: OutIdxT{SymReal} , ranges:: RangesT{SymReal} , reduce):: BasicSymbolic{SymReal}
@@ -335,9 +345,9 @@ function scalarize(x::BasicSymbolic{T}, ::Val{toplevel} = Val{false}()) where {T
335345 subrules[ii] = idxs[i]
336346 end
337347 if toplevel
338- substitute (new_expr, subrules; fold = true )
348+ substitute (new_expr, subrules; fold = Val { true} () )
339349 else
340- scalarize (substitute (new_expr, subrules; fold = true ))
350+ scalarize (substitute (new_expr, subrules; fold = Val { true} () ))
341351 end
342352 end
343353 end
0 commit comments