From cefcf727a2e21b2e7cfc1a42def997bd609a4a3b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 19 Mar 2025 15:17:55 +0530 Subject: [PATCH] fix: remove mutation of `BasicSymbolic` --- src/latexify_recipes.jl | 2 +- src/solver/ia_main.jl | 19 +++++++----- src/solver/polynomialization.jl | 53 +++++++++++++++++---------------- src/solver/preprocess.jl | 10 ++++--- src/solver/solve_helpers.jl | 3 +- 5 files changed, 48 insertions(+), 39 deletions(-) diff --git a/src/latexify_recipes.jl b/src/latexify_recipes.jl index 58dcdd3b4..7bc6f33e7 100644 --- a/src/latexify_recipes.jl +++ b/src/latexify_recipes.jl @@ -216,7 +216,7 @@ function _toexpr(O) while num isa Term && num.f isa Differential deg += 1 den *= num.f.x - num = num.arguments[1] + num = first(arguments(num)) end return :(_derivative($(_toexpr(num)), $den, $deg)) elseif op isa Integral diff --git a/src/solver/ia_main.jl b/src/solver/ia_main.jl index 7d00e06a1..39cfd0e48 100644 --- a/src/solver/ia_main.jl +++ b/src/solver/ia_main.jl @@ -33,7 +33,12 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri for i in eachindex(lhs_roots) for j in eachindex(rhs) if iscall(lhs_roots[i]) && operation(lhs_roots[i]) == RootsOf - lhs_roots[i].arguments[1] = substitute(lhs_roots[i].arguments[1], Dict(new_var=>rhs[j]), fold=false) + _args = copy(parent(arguments(lhs_roots[i]))) + _args[1] = substitute(_args[1], Dict(new_var => rhs[j]), fold = false) + T = typeof(lhs_roots[i]) + _op = operation(lhs_roots[i]) + _meta = metadata(lhs_roots[i]) + lhs_roots[i] = maketerm(T, _op, _args, _meta) push!(roots, lhs_roots[i]) else push!(roots, substitute(lhs_roots[i], Dict(new_var=>rhs[j]), fold=false)) @@ -86,8 +91,9 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri end elseif oper === (^) - if any(isequal(x, var) for x in get_variables(args[1])) && - n_occurrences(args[2], var) == 0 && args[2] isa Integer + var_in_base = any(isequal(x, var) for x in get_variables(args[1])) + var_in_pow = n_occurrences(args[2], var) != 0 + if var_in_base && !var_in_pow && args[2] isa Integer lhs = args[1] power = args[2] new_roots = [] @@ -111,11 +117,10 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri end rhs = [] append!(rhs, new_roots) - elseif any(isequal(x, var) for x in get_variables(args[1])) && - n_occurrences(args[2], var) == 0 + elseif var_in_base && !var_in_pow lhs = args[1] - s, args[2] = filter_stuff(args[2]) - rhs = map(sol -> term(^, sol, 1 // args[2]), rhs) + s, power = filter_stuff(args[2]) + rhs = map(sol -> term(^, sol, 1 // power), rhs) else lhs = args[2] rhs = map(sol -> term(/, term(slog, sol), term(slog, args[1])), rhs) diff --git a/src/solver/polynomialization.jl b/src/solver/polynomialization.jl index 346037c13..6b0c9e8cd 100644 --- a/src/solver/polynomialization.jl +++ b/src/solver/polynomialization.jl @@ -43,7 +43,7 @@ function turn_to_poly(expr, var) expr = unwrap(expr) !iscall(expr) && return (expr, Dict()) - args = arguments(expr) + args = copy(parent(arguments(expr))) sub = 0 broken = Ref(false) @@ -53,12 +53,12 @@ function turn_to_poly(expr, var) arg_oper = operation(arg) if arg_oper === (^) - tp = trav_pow(args, i, var, broken, sub) + args[i], tp = trav_pow(args[i], var, broken, sub) sub = isequal(tp, false) ? sub : tp continue end if arg_oper === (*) - sub = trav_mult(arg, var, broken, sub) + args[i], sub = trav_mult(arg, var, broken, sub) continue end isequal(add_sub(sub, arg, var, broken), false) && continue @@ -77,16 +77,17 @@ function turn_to_poly(expr, var) new_var = gensym() new_var = (@variables $new_var)[1] + expr = maketerm(typeof(expr), operation(expr), args, metadata(expr)) return ssubs(expr, Dict(sub => new_var)), Dict{Any, Any}(new_var => sub) end """ - trav_pow(args, index, var, broken, sub) + trav_pow(arg, var, broken, sub) -Traverses an argument passed from ``turn_to_poly`` if it -satisfies ``oper === (^)``. Returns sub if changed from 0 -to a new transcendental function or its value is -kept the same, and false if these 2 cases do not occur. +Traverses an argument `arg` passed from ``turn_to_poly`` if it satisfies +``oper === (^)``. Returns the new `arg` and `sub` if `sub` is changed from 0 to a new +transcendental function or its value is kept the same, or else `false` if these 2 cases +do not occur. # Arguments - args: The original arguments array of the expression passed to ``turn_to_poly`` @@ -97,20 +98,20 @@ kept the same, and false if these 2 cases do not occur. # Examples ```jldoctest -julia> trav_pow([unwrap(9^x)], 1, x, Ref(false), 3^x) -3^x +julia> trav_pow(unwrap(9^x), x, Ref(false), 3^x) +(9^x, 3^x) -julia> trav_pow([unwrap(x^2)], 1, x, Ref(false), 3^x) -false +julia> trav_pow(unwrap(x^2), x, Ref(false), 3^x) +(x^2, false) ``` """ -function trav_pow(args, index, var, broken, sub) - args_arg = arguments(args[index]) +function trav_pow(arg, var, broken, sub) + args_arg = arguments(arg) base = args_arg[1] power = args_arg[2] # case 1: log(x)^2 .... 9^x = 3^2^x = 3^2x = (3^x)^2 - !isequal(add_sub(sub, base, var, broken), false) && power isa Integer && return base + !isequal(add_sub(sub, base, var, broken), false) && power isa Integer && return arg, base # case 2: int^f(x) # n_func_occ may not be strictly 1, we could attempt attracting it after solving @@ -122,21 +123,20 @@ function trav_pow(args, index, var, broken, sub) sub = isequal(sub, 0) ? new_b : sub if !isequal(sub, new_b) broken[] = true - return false + return arg, false end new_b = term(^, new_b, p) - args[index] = new_b - return sub + return new_b, sub end - return false + return arg, false end """ trav_mult(arg, var, broken, sub) Traverses an argument passed from ``turn_to_poly`` if it -satisfies ``oper === (*)``. Returns sub whether its changed from 0 +satisfies ``oper === (*)``. Returns the new `arg` and `sub` if its changed from 0 to a new transcendental function or its value is kept the same, but changes broken if these 2 cases do not occur. It traverses the * argument by sub_arg and compares it to sub using @@ -151,24 +151,24 @@ the function ``add_sub`` # Examples ```jldoctest julia> trav_mult(unwrap(9*log(x)), x, Ref(false), log(x)) -log(x) +(9log(x), log(x)) julia> trav_mult(unwrap(9*log(x)^2), x, Ref(false), log(x)) -log(x) +(9(log(x)^2), log(x)) # value of broken is changed here to true julia> trav_mult(unwrap(9*log(x+1)), x, Ref(false), log(x)) -log(x) +(9log(x + 1), log(x)) ``` """ function trav_mult(arg, var, broken, sub) - args_arg = arguments(arg) + args_arg = copy(parent(arguments(arg))) for (i, arg2) in enumerate(args_arg) !iscall(arg2) && continue oper = operation(arg2) if oper === (^) - tp = trav_pow(args_arg, i, var, broken, sub) + args_arg[i], tp = trav_pow(args_arg[i], var, broken, sub) sub = isequal(tp, false) ? sub : tp continue end @@ -176,7 +176,8 @@ function trav_mult(arg, var, broken, sub) isequal(add_sub(sub, arg2, var, broken), false) && continue sub = arg2 end - return sub + arg = maketerm(typeof(arg), operation(arg), args_arg, metadata(arg)) + return arg, sub end """ diff --git a/src/solver/preprocess.jl b/src/solver/preprocess.jl index f506b434d..0e8949a36 100644 --- a/src/solver/preprocess.jl +++ b/src/solver/preprocess.jl @@ -116,7 +116,7 @@ function _filter_poly(expr, var) return filter_stuff(expr) end - args = arguments(expr) + args = copy(parent(arguments(expr))) if expr isa ComplexTerm subs1, subs2 = Dict(), Dict() expr1, expr2 = 0, 0 @@ -165,7 +165,7 @@ function _filter_poly(expr, var) end oper = operation(arg) - monomial = arguments(arg) + monomial = copy(parent(arguments(arg))) if oper === (^) if any(arg -> isequal(arg, var), monomial) continue @@ -175,6 +175,7 @@ function _filter_poly(expr, var) subs2, monomial[2] = _filter_poly(monomial[2], var) merge!(subs, merge(subs1, subs2)) + args[i] = maketerm(typeof(arg), oper, monomial, metadata(arg)) continue end @@ -196,6 +197,7 @@ function _filter_poly(expr, var) merge!(subs_of_monom, new_subs) end merge!(subs, subs_of_monom) + args[i] = maketerm(typeof(arg), oper, monomial, metadata(arg)) continue end @@ -208,9 +210,9 @@ function _filter_poly(expr, var) end end - args = map(unwrap, arguments(expr)) + args = map(unwrap, args) oper = operation(expr) - expr = term(oper, args...) + expr = maketerm(typeof(expr), oper, args, metadata(expr)) return subs, expr end diff --git a/src/solver/solve_helpers.jl b/src/solver/solve_helpers.jl index 13b70af22..f70ebdb05 100644 --- a/src/solver/solve_helpers.jl +++ b/src/solver/solve_helpers.jl @@ -116,10 +116,11 @@ function bigify(n) if n isa SymbolicUtils.BasicSymbolic !iscall(n) && return n - args = arguments(n) + args = copy(parent(arguments(n))) for i in eachindex(args) args[i] = bigify(args[i]) end + n = maketerm(typeof(n), operation(n), args, metadata(n)) return n end