@@ -489,16 +489,33 @@ size_all_negative(::SpinGlass) = false
489
489
size_all_positive (:: SpinGlass ) = false
490
490
491
491
# NOTE: `findmin` and `findmax` are required by `ProblemReductions.jl`
492
+ """
493
+ GTNSolver(; optimizer=TreeSA(), single=false, usecuda=false, T=Float64)
494
+
495
+ A generic tensor network based backend for the `findbest`, `findmin` and `findmax` interfaces in `ProblemReductions.jl`.
496
+
497
+ Keyword arguments
498
+ -------------------------------------
499
+ * `optimizer` is the optimizer for the tensor network contraction.
500
+ * `single` is a switch to return single solution instead of all solutions.
501
+ * `usecuda` is a switch to use CUDA (when applicable), user need to call statement `using CUDA` before turning on this switch.
502
+ * `T` is the "base" element type, sometimes can be used to reduce the memory cost.
503
+ """
492
504
Base. @kwdef struct GTNSolver
493
505
optimizer:: OMEinsum.CodeOptimizer = TreeSA ()
506
+ single:: Bool = false
494
507
usecuda:: Bool = false
495
508
T:: Type = Float64
496
509
end
497
- function Base. findmin (problem:: AbstractProblem , solver:: GTNSolver )
498
- res = collect (solve (GenericTensorNetwork (problem; optimizer= solver. optimizer), ConfigsMin (; tree_storage= true ); usecuda= solver. usecuda, T= solver. T)[]. c)
499
- return map (x -> ProblemReductions. id_to_config (problem, Int .(x) .+ 1 ), res)
500
- end
501
- function Base. findmax (problem:: AbstractProblem , solver:: GTNSolver )
502
- res = collect (solve (GenericTensorNetwork (problem; optimizer= solver. optimizer), ConfigsMax (; tree_storage= true ); usecuda= solver. usecuda, T= solver. T)[]. c)
503
- return map (x -> ProblemReductions. id_to_config (problem, Int .(x) .+ 1 ), res)
504
- end
510
+ for (PROP, SPROP, SOLVER) in [
511
+ (:ConfigsMin , :SingleConfigMin , :findmin ), (:ConfigsMax , :SingleConfigMax , :findmax )
512
+ ]
513
+ @eval function Base. $ (SOLVER)(problem:: AbstractProblem , solver:: GTNSolver )
514
+ if solver. single
515
+ res = [solve (GenericTensorNetwork (problem; optimizer= solver. optimizer), $ (SPROP)(); usecuda= solver. usecuda, T= solver. T)[]. c. data]
516
+ else
517
+ res = collect (solve (GenericTensorNetwork (problem; optimizer= solver. optimizer), $ (PROP)(; tree_storage= true ); usecuda= solver. usecuda, T= solver. T)[]. c)
518
+ end
519
+ return map (x -> ProblemReductions. id_to_config (problem, Int .(x) .+ 1 ), res)
520
+ end
521
+ end
0 commit comments