Skip to content

Commit e71d8ca

Browse files
committed
max3 configs
1 parent 308c20c commit e71d8ca

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

src/arithematics.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ end
191191
# convert from counting type to bitstring type
192192
for (F,TP) in [(:set_type, :ConfigEnumerator), (:sampler_type, :ConfigSampler)]
193193
@eval begin
194-
function $F(::Type{T}, n::Int, nflavor::Int) where {OT, T<:Max2Poly{C,OT} where C}
195-
Max2Poly{$F(n,nflavor),OT}
194+
function $F(::Type{T}, n::Int, nflavor::Int) where {OT, K, T<:TruncatedPoly{K,C,OT} where C}
195+
TruncatedPoly{K, $F(n,nflavor),OT}
196196
end
197197
function $F(::Type{T}, n::Int, nflavor::Int) where {TX, T<:Polynomial{C,TX} where C}
198198
Polynomial{$F(n,nflavor),:x}
@@ -212,8 +212,8 @@ end
212212
function onehotv(::Type{Polynomial{BS,X}}, x, v) where {BS,X}
213213
Polynomial{BS,X}([zero(BS), onehotv(BS, x, v)])
214214
end
215-
function onehotv(::Type{Max2Poly{BS,OS}}, x, v) where {BS,OS}
216-
Max2Poly{BS,OS}(zero(BS), onehotv(BS, x, v),one(OS))
215+
function onehotv(::Type{TruncatedPoly{K,BS,OS}}, x, v) where {K,BS,OS}
216+
TruncatedPoly{K,BS,OS}(ntuple(i->i<K ? zero(BS) : onehotv(BS, x, v), K),one(OS))
217217
end
218218
function onehotv(::Type{CountingTropical{TV,BS}}, x, v) where {TV,BS}
219219
CountingTropical{TV,BS}(one(TV), onehotv(BS, x, v))

src/interfaces.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ function solve(gp::GraphProblem, task; usecuda=false, kwargs...)
2828
return contractx(gp, CountingTropical(1.0); usecuda=usecuda)
2929
elseif task == "counting max2"
3030
return contractx(gp, Max2Poly(0.0, 1.0, 1.0); usecuda=usecuda)
31+
elseif task == "counting max3"
32+
return contractx(gp, TruncatedPoly((0.0, 0.0, 1.0), 1.0); usecuda=usecuda)
3133
elseif task == "counting all"
3234
return graph_polynomial(gp, Val(:polynomial); usecuda=usecuda)
3335
elseif task == "config max"
@@ -36,6 +38,8 @@ function solve(gp::GraphProblem, task; usecuda=false, kwargs...)
3638
return solutions(gp, CountingTropical{Float64,Float64}; all=true, usecuda=usecuda)
3739
elseif task == "configs max2"
3840
return solutions(gp, Max2Poly{Float64,Float64}; all=true, usecuda=usecuda)
41+
elseif task == "configs max3"
42+
return solutions(gp, TruncatedPoly{3,Float64,Float64}; all=true, usecuda=usecuda)
3943
elseif task == "configs all"
4044
return solutions(gp, Polynomial{Float64,:x}; all=true, usecuda=usecuda)
4145
# extra methods

test/interfaces.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ using LightGraphs, Test
1717
res11 = solve(gp, "counting all (finitefield)")[]
1818
res12 = solve(gp, "config max (bounded)")[]
1919
res13 = solve(gp, "configs max (bounded)")[]
20+
res14 = solve(gp, "counting max3")[]
21+
res15 = solve(gp, "configs max3")[]
2022
@test res1.n == 4
2123
@test res2 == 76
2224
@test res3.n == 4 && res3.c == 5
@@ -30,6 +32,9 @@ using LightGraphs, Test
3032
@test res11 == res5
3133
@test res12.c.data res13.c.data
3234
@test res13.c.data == res7.c.data
35+
@test res14.maxorder == 4 && res14.coeffs[1]==30 && res14.coeffs[2] == 30 && res14.coeffs[3]==5
36+
@test all(x->sum(x) == 2, res15.coeffs[1].data) && all(x->sum(x) == 3, res15.coeffs[2].data) && all(x->sum(x) == 4, res15.coeffs[3].data) &&
37+
length(res15.coeffs[1].data) == 30 && length(res15.coeffs[2].data) == 30 && length(res15.coeffs[3].data) == 5
3338
end
3439

3540
@testset "save load" begin

0 commit comments

Comments
 (0)