diff --git a/src/MacroModelling.jl b/src/MacroModelling.jl index 7d542fef..90cd66a2 100644 --- a/src/MacroModelling.jl +++ b/src/MacroModelling.jl @@ -2172,12 +2172,13 @@ function convert_to_ss_equation(eq::Expr)::Expr end -function replace_indices_inside_for_loop(exxpr,index_variable,indices,concatenate, operator) - @assert operator ∈ [:+,:*] "Only :+ and :* allowed as operators in for loops." +function replace_indices_inside_for_loop(exxpr, index_variable, indices, concatenate, operator) + @assert operator ∈ [:+, :*] "Only :+ and :* allowed as operators in for loops." calls = [] indices = indices.args[1] == :(:) ? eval(indices) : [indices.args...] + index_syms = Symbol[i for i in indices if i isa Symbol] for idx in indices - push!(calls, postwalk(x -> begin + replaced = postwalk(x -> begin x isa Expr ? x.head == :ref ? @capture(x, name_{index_}[time_]) ? @@ -2205,16 +2206,42 @@ function replace_indices_inside_for_loop(exxpr,index_variable,indices,concatenat x : x : @capture(x, name_) ? - name == index_variable && idx isa Int ? - :($idx) : + name == index_variable ? + idx isa Int ? + :($idx) : + QuoteNode(idx) : x isa Symbol ? occursin("{" * string(index_variable) * "}", string(x)) ? - Symbol(replace(string(x), "{" * string(index_variable) * "}" => "{" * string(idx) * "}")) : + Symbol(replace(string(x), "{" * string(index_variable) * "}" => "{" * string(idx) * "}")) : + x in index_syms ? QuoteNode(x) : + x : x : - x : - x - end, - exxpr)) + x + end, exxpr) + + skip = false + if replaced isa Expr && replaced.head == :if + cond = replaced.args[1] + try + val = eval(cond) + if val + replaced = replaced.args[2] + elseif length(replaced.args) >= 3 + replaced = replaced.args[3] + else + skip = true + end + catch + end + elseif replaced isa Expr && replaced.head == :call && replaced.args[1] == :ifelse + cond = replaced.args[2] + try + val = eval(cond) + replaced = val ? replaced.args[3] : replaced.args[4] + catch + end + end + skip || push!(calls, replaced) end if concatenate diff --git a/test/models/for_if_no_else.jl b/test/models/for_if_no_else.jl new file mode 100644 index 00000000..9c5f3037 --- /dev/null +++ b/test/models/for_if_no_else.jl @@ -0,0 +1,7 @@ +@model IfNoElseLoop begin + for co in [H, F] + if co == H + x{co}[0] = 1 + end + end +end diff --git a/test/models/for_if_statement.jl b/test/models/for_if_statement.jl new file mode 100644 index 00000000..73830370 --- /dev/null +++ b/test/models/for_if_statement.jl @@ -0,0 +1,9 @@ +@model IfStatementLoop begin + for co in [H, F] + if co == H + x{co}[0] = 1 + else + x{co}[0] = 2 + end + end +end diff --git a/test/models/for_ifelse.jl b/test/models/for_ifelse.jl new file mode 100644 index 00000000..134bd799 --- /dev/null +++ b/test/models/for_ifelse.jl @@ -0,0 +1,5 @@ +@model IfElseLoop begin + for co in [H, F] + x{co}[0] = ifelse(co == H, 1, 2) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index f7960c4c..4949704a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,6 +20,24 @@ println("Threads used: ", Threads.nthreads()) include("functionality_tests.jl") +@testset "for loop ifelse" begin + include("models/for_ifelse.jl") + @test length(IfElseLoop.dyn_equations) == 2 + @test IfElseLoop.var == [:x◖H◗, :x◖F◗] +end + +@testset "for loop if" begin + include("models/for_if_statement.jl") + @test length(IfStatementLoop.dyn_equations) == 2 + @test IfStatementLoop.var == [:x◖H◗, :x◖F◗] +end + +@testset "for loop if no else" begin + include("models/for_if_no_else.jl") + @test length(IfNoElseLoop.dyn_equations) == 1 + @test IfNoElseLoop.var == [:x◖H◗] +end + # @testset verbose = true "Code formatting (JuliaFormatter.jl)" begin # @test format(MacroModelling; verbose=true, overwrite=true) # end