Skip to content

Commit 8432a0a

Browse files
authored
allow single config in solver (#92)
1 parent 4759afb commit 8432a0a

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

src/interfaces.jl

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -489,16 +489,33 @@ size_all_negative(::SpinGlass) = false
489489
size_all_positive(::SpinGlass) = false
490490

491491
# NOTE: `findmin` and `findmax` are required by `ProblemReductions.jl`
492+
"""
493+
GTNSolver(; optimizer=TreeSA(), single=false, usecuda=false, T=Float64)
494+
495+
A generic tensor network based backend for the `findbest`, `findmin` and `findmax` interfaces in `ProblemReductions.jl`.
496+
497+
Keyword arguments
498+
-------------------------------------
499+
* `optimizer` is the optimizer for the tensor network contraction.
500+
* `single` is a switch to return single solution instead of all solutions.
501+
* `usecuda` is a switch to use CUDA (when applicable), user need to call statement `using CUDA` before turning on this switch.
502+
* `T` is the "base" element type, sometimes can be used to reduce the memory cost.
503+
"""
492504
Base.@kwdef struct GTNSolver
493505
optimizer::OMEinsum.CodeOptimizer = TreeSA()
506+
single::Bool = false
494507
usecuda::Bool = false
495508
T::Type = Float64
496509
end
497-
function Base.findmin(problem::AbstractProblem, solver::GTNSolver)
498-
res = collect(solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), ConfigsMin(; tree_storage=true); usecuda=solver.usecuda, T=solver.T)[].c)
499-
return map(x -> ProblemReductions.id_to_config(problem, Int.(x) .+ 1), res)
500-
end
501-
function Base.findmax(problem::AbstractProblem, solver::GTNSolver)
502-
res = collect(solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), ConfigsMax(; tree_storage=true); usecuda=solver.usecuda, T=solver.T)[].c)
503-
return map(x -> ProblemReductions.id_to_config(problem, Int.(x) .+ 1), res)
504-
end
510+
for (PROP, SPROP, SOLVER) in [
511+
(:ConfigsMin, :SingleConfigMin, :findmin), (:ConfigsMax, :SingleConfigMax, :findmax)
512+
]
513+
@eval function Base.$(SOLVER)(problem::AbstractProblem, solver::GTNSolver)
514+
if solver.single
515+
res = [solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), $(SPROP)(); usecuda=solver.usecuda, T=solver.T)[].c.data]
516+
else
517+
res = collect(solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), $(PROP)(; tree_storage=true); usecuda=solver.usecuda, T=solver.T)[].c)
518+
end
519+
return map(x -> ProblemReductions.id_to_config(problem, Int.(x) .+ 1), res)
520+
end
521+
end

test/interfaces.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,4 +254,8 @@ end
254254
solver2 = BruteForce()
255255
@test Set(findmin(sg, solver1)) == Set(findmin(sg, solver2))
256256
@test Set(findmax(sg, solver1)) == Set(findmax(sg, solver2))
257+
258+
solver3 = GTNSolver(; optimizer=TreeSA(ntrials=1), single=true)
259+
@test findmin(sg, solver3)[] findmin(sg, solver2)
260+
@test findmax(sg, solver3)[] findmax(sg, solver2)
257261
end

0 commit comments

Comments
 (0)