Skip to content

Commit 3647afe

Browse files
committed
2 parents a2be767 + d9065f1 commit 3647afe

File tree

7 files changed

+45
-14
lines changed

7 files changed

+45
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GraphTensorNetworks"
22
uuid = "0978c8c2-34f6-49c7-9826-ea2cc20dabd2"
33
authors = ["GiggleLiu <[email protected]> and contributors"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/arithematics.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,4 +219,6 @@ function onehotv(::Type{CountingTropical{TV,BS}}, x, v) where {TV,BS}
219219
CountingTropical{TV,BS}(one(TV), onehotv(BS, x, v))
220220
end
221221
onehotv(::Type{ConfigEnumerator{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigEnumerator([onehotv(StaticElementVector{N,S,C}, i, v)])
222-
onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampler(onehotv(StaticElementVector{N,S,C}, i, v))
222+
onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampler(onehotv(StaticElementVector{N,S,C}, i, v))
223+
Base.transpose(c::ConfigEnumerator) = c
224+
Base.copy(c::ConfigEnumerator) = ConfigEnumerator(copy(c.data))

src/bounding.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
using OMEinsum: DynamicEinCode
22

3+
struct AllConfigs{K} end
4+
largest_k(::AllConfigs{K}) where K = K
5+
struct SingleConfig end
6+
37
"""
48
backward_tropical(mode, ixs, xs, iy, y, ymask, size_dict)
59
@@ -17,11 +21,11 @@ function backward_tropical(mode, ixs, @nospecialize(xs::Tuple), iy, @nospecializ
1721
nixs = OMEinsum._insertat(ixs, i, iy)
1822
nxs = OMEinsum._insertat( xs, i, y)
1923
niy = ixs[i]
20-
if mode == :all
24+
if mode isa AllConfigs
2125
mask = zeros(Bool, size(xs[i]))
22-
mask .= inv.(einsum(EinCode(nixs, niy), nxs, size_dict)) .== xs[i]
26+
mask .= inv.(einsum(EinCode(nixs, niy), nxs, size_dict)) .<= xs[i] .* Tropical(largest_k(mode)-1)
2327
push!(masks, mask)
24-
elseif mode == :single # wrong, need `B` matching `A`.
28+
elseif mode isa SingleConfig
2529
A = zeros(eltype(xs[i]), size(xs[i]))
2630
A = einsum(EinCode(nixs, niy), nxs, size_dict)
2731
push!(masks, onehotmask(A, xs[i]))
@@ -65,12 +69,12 @@ function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
6569
end
6670

6771
# computed mask tree by back propagation
68-
function generate_masktree(code::NestedEinsum, cache, mask, size_dict, mode=:all)
72+
function generate_masktree(mode, code::NestedEinsum, cache, mask, size_dict)
6973
if OMEinsum.isleaf(code)
7074
return CacheTree(mask, CacheTree{Bool}[])
7175
else
7276
submasks = backward_tropical(mode, getixs(code.eins), (getfield.(cache.siblings, :content)...,), OMEinsum.getiy(code.eins), cache.content, mask, size_dict)
73-
return CacheTree(mask, generate_masktree.(code.args, cache.siblings, submasks, Ref(size_dict), mode))
77+
return CacheTree(mask, generate_masktree.(Ref(mode), code.args, cache.siblings, submasks, Ref(size_dict)))
7478
end
7579
end
7680

@@ -89,27 +93,28 @@ function masked_einsum(code::NestedEinsum, @nospecialize(xs), masks, size_dict)
8993
end
9094

9195
"""
92-
bounding_contract(code, xsa, ymask, xsb; size_info=nothing)
96+
bounding_contract(mode, code, xsa, ymask, xsb; size_info=nothing)
9397
9498
Contraction method with bounding.
9599
100+
* `mode` is a `AllConfigs{K}` instance, where `MIS-K+1` is the largest IS size that you care about.
96101
* `xsa` are input tensors for bounding, e.g. tropical tensors,
97102
* `xsb` are input tensors for computing, e.g. tensors elements are counting tropical with set algebra,
98103
* `ymask` is the initial gradient mask for the output tensor.
99104
"""
100-
function bounding_contract(code::EinCode, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
105+
function bounding_contract(mode::AllConfigs, code::EinCode, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
101106
LT = OMEinsum.labeltype(code)
102-
bounding_contract(NestedEinsum(NestedEinsum{DynamicEinCode{LT}}.(1:length(xsa)), code), xsa, ymask, xsb; size_info=size_info)
107+
bounding_contract(mode, NestedEinsum(NestedEinsum{DynamicEinCode{LT}}.(1:length(xsa)), code), xsa, ymask, xsb; size_info=size_info)
103108
end
104-
function bounding_contract(code::NestedEinsum, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
109+
function bounding_contract(mode::AllConfigs, code::NestedEinsum, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
105110
size_dict = size_info===nothing ? Dict{OMEinsum.labeltype(code.eins),Int}() : copy(size_info)
106111
OMEinsum.get_size_dict!(code, xsa, size_dict)
107112
# compute intermediate tensors
108113
@debug "caching einsum..."
109114
c = cached_einsum(code, xsa, size_dict)
110115
# compute masks from cached tensors
111116
@debug "generating masked tree..."
112-
mt = generate_masktree(code, c, ymask, size_dict, :all)
117+
mt = generate_masktree(mode, code, c, ymask, size_dict)
113118
# compute results with masks
114119
masked_einsum(code, xsb, mt, size_dict)
115120
end
@@ -129,7 +134,7 @@ function solution_ad(code::NestedEinsum, @nospecialize(xsa), ymask; size_info=no
129134
n = asscalar(c.content)
130135
# compute masks from cached tensors
131136
@debug "generating masked tree..."
132-
mt = generate_masktree(code, c, ymask, size_dict, :single)
137+
mt = generate_masktree(SingleConfig(), code, c, ymask, size_dict)
133138
n, read_config!(code, mt, Dict())
134139
end
135140

src/configurations.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
export best_solutions, best2_solutions, solutions, all_solutions
2+
export bestk_solutions
23

34
"""
45
best_solutions(problem; all=false, usecuda=false)
@@ -23,7 +24,7 @@ function best_solutions(gp::GraphProblem; all=false, usecuda=false)
2324
end
2425
if all
2526
xs = generate_tensors(l->onehotv(T, vertex_index[l], 1), gp)
26-
return bounding_contract(gp.code, xst, ymask, xs)
27+
return bounding_contract(AllConfigs{1}(), gp.code, xst, ymask, xs)
2728
else
2829
@assert ndims(ymask) == 0
2930
t, res = solution_ad(gp.code, xst, ymask)
@@ -58,6 +59,16 @@ Finding optimal and suboptimal solutions.
5859
"""
5960
best2_solutions(gp::GraphProblem; all=true, usecuda=false) = solutions(gp, Max2Poly{Float64,Float64}; all=all, usecuda=usecuda)
6061

62+
function bestk_solutions(gp::GraphProblem, k::Int)
63+
syms = symbols(gp)
64+
vertex_index = Dict([s=>i for (i, s) in enumerate(syms)])
65+
xst = generate_tensors(l->TropicalF64(1.0), gp)
66+
ymask = trues(fill(2, length(_getiy(gp.code)))...)
67+
T = set_type(TruncatedPoly{k,Float64,Float64}, length(syms), bondsize(gp))
68+
xs = generate_tensors(l->onehotv(T, vertex_index[l], 1), gp)
69+
return bounding_contract(AllConfigs{k}(), gp.code, xst, ymask, xs)
70+
end
71+
6172
"""
6273
all_solutions(problem)
6374

src/interfaces.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ function solve(gp::GraphProblem, task; usecuda=false, kwargs...)
5151
return best_solutions(gp; all=false, usecuda=usecuda)
5252
elseif task == "configs max (bounded)"
5353
return best_solutions(gp; all=true, usecuda=usecuda)
54+
elseif task == "configs max2 (bounded)"
55+
return bestk_solutions(gp, 2)
56+
elseif task == "configs max3 (bounded)"
57+
return bestk_solutions(gp, 3)
5458
else
5559
error("unknown task $task.")
5660
end

test/configurations.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using GraphTensorNetworks, Test, Graphs
22
using OMEinsum
33
using TropicalNumbers: CountingTropicalF64
4+
using OMEinsumContractionOrders: uniformsize
45

56
@testset "Config types" begin
67
T = sampler_type(CountingTropical{Float32}, 5, 2)
@@ -45,10 +46,13 @@ end
4546
@test res5.n == res0
4647
@test res5.c.data res2.c.data
4748
res6 = best2_solutions(code; all=true)[]
49+
res6_ = bestk_solutions(code, 2)[]
4850
res7 = all_solutions(code)[]
4951
idp = graph_polynomial(code, Val(:finitefield))[]
5052
@test all(x->x res7.coeffs[end-1].data, res6.coeffs[1].data)
5153
@test all(x->x res7.coeffs[end].data, res6.coeffs[2].data)
54+
@test all(x->x res7.coeffs[end-1].data, res6_.coeffs[1].data)
55+
@test all(x->x res7.coeffs[end].data, res6_.coeffs[2].data)
5256
for (i, (s, c)) in enumerate(zip(res7.coeffs, idp.coeffs))
5357
@test length(s) == c
5458
@test all(x->count_ones(x)==(i-1), s.data)

test/interfaces.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ using Graphs, Test
1919
res13 = solve(gp, "configs max (bounded)")[]
2020
res14 = solve(gp, "counting max3")[]
2121
res15 = solve(gp, "configs max3")[]
22+
res16 = solve(gp, "configs max2 (bounded)")[]
23+
res17 = solve(gp, "configs max3 (bounded)")[]
2224
@test res1.n == 4
2325
@test res2 == 76
2426
@test res3.n == 4 && res3.c == 5
@@ -35,6 +37,9 @@ using Graphs, Test
3537
@test res14.maxorder == 4 && res14.coeffs[1]==30 && res14.coeffs[2] == 30 && res14.coeffs[3]==5
3638
@test all(x->sum(x) == 2, res15.coeffs[1].data) && all(x->sum(x) == 3, res15.coeffs[2].data) && all(x->sum(x) == 4, res15.coeffs[3].data) &&
3739
length(res15.coeffs[1].data) == 30 && length(res15.coeffs[2].data) == 30 && length(res15.coeffs[3].data) == 5
40+
@test all(x->sum(x) == 3, res16.coeffs[1].data) && all(x->sum(x) == 4, res16.coeffs[2].data) && length(res16.coeffs[1].data) == 30 && length(res16.coeffs[2].data) == 5
41+
@test all(x->sum(x) == 2, res17.coeffs[1].data) && all(x->sum(x) == 3, res17.coeffs[2].data) && all(x->sum(x) == 4, res17.coeffs[3].data) &&
42+
length(res17.coeffs[1].data) == 30 && length(res17.coeffs[2].data) == 30 && length(res17.coeffs[3].data) == 5
3843
end
3944

4045
@testset "save load" begin

0 commit comments

Comments
 (0)