Skip to content

Commit 87b5a43

Browse files
authored
Support slicing (#5)
* support slicing * support slicing basic * use warn instead of errors * bump version * rm warning test
1 parent fcf9823 commit 87b5a43

File tree

9 files changed

+119
-54
lines changed

9 files changed

+119
-54
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ FFTW = "1.4"
2828
Graphs = "1.4"
2929
Mods = "1.3"
3030
OMEinsum = "0.6.1"
31-
OMEinsumContractionOrders = "0.5"
31+
OMEinsumContractionOrders = "0.6"
3232
Polynomials = "2.0"
3333
Primes = "0.5"
3434
Requires = "1"

src/GraphTensorNetworks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module GraphTensorNetworks
22

3-
using OMEinsumContractionOrders: OMEinsum
3+
using OMEinsumContractionOrders: SlicedEinsum
44
using Core: Argument
55
using TropicalGEMM, TropicalNumbers
66
using OMEinsum

src/arithematics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ end
130130

131131
Base.length(x::ConfigEnumerator{N}) where N = length(x.data)
132132
Base.getindex(x::ConfigEnumerator, i) = x.data[i]
133-
Base.:(==)(x::ConfigEnumerator{N,S,C}, y::ConfigEnumerator{N,S,C}) where {N,S,C} = x.data == y.data
133+
Base.:(==)(x::ConfigEnumerator{N,S,C}, y::ConfigEnumerator{N,S,C}) where {N,S,C} = Set(x.data) == Set(y.data)
134134

135135
function Base.:+(x::ConfigEnumerator{N,S,C}, y::ConfigEnumerator{N,S,C}) where {N,S,C}
136136
length(x) == 0 && return y

src/bounding.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using OMEinsum: DynamicEinCode
22

3+
export AllConfigs, SingleConfig
4+
35
struct AllConfigs{K} end
46
largest_k(::AllConfigs{K}) where K = K
57
struct SingleConfig end
@@ -57,6 +59,12 @@ struct CacheTree{T}
5759
content::AbstractArray{T}
5860
siblings::Vector{CacheTree{T}}
5961
end
62+
function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
63+
if length(se.slicing) != 0
64+
@warn "Slicing is not supported for caching! Fallback to `NestedEinsum`."
65+
end
66+
return cached_einsum(se.eins, xs, size_dict)
67+
end
6068
function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
6169
if OMEinsum.isleaf(code)
6270
y = xs[code.tensorindex]
@@ -69,6 +77,12 @@ function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
6977
end
7078

7179
# computed mask tree by back propagation
80+
function generate_masktree(mode, se::SlicedEinsum, cache, mask, size_dict)
81+
if length(se.slicing) != 0
82+
@warn "Slicing is not supported for generating masked tree! Fallback to `NestedEinsum`."
83+
end
84+
return generate_masktree(mode, se.eins, cache, mask, size_dict)
85+
end
7286
function generate_masktree(mode, code::NestedEinsum, cache, mask, size_dict)
7387
if OMEinsum.isleaf(code)
7488
return CacheTree(mask, CacheTree{Bool}[])
@@ -79,6 +93,12 @@ function generate_masktree(mode, code::NestedEinsum, cache, mask, size_dict)
7993
end
8094

8195
# The masked einsum contraction
96+
function masked_einsum(se::SlicedEinsum, @nospecialize(xs), masks, size_dict)
97+
if length(se.slicing) != 0
98+
@warn "Slicing is not supported for masked contraction! Fallback to `NestedEinsum`."
99+
end
100+
return masked_einsum(se.eins, xs, masks, size_dict)
101+
end
82102
function masked_einsum(code::NestedEinsum, @nospecialize(xs), masks, size_dict)
83103
if OMEinsum.isleaf(code)
84104
y = copy(xs[code.tensorindex])
@@ -106,7 +126,7 @@ function bounding_contract(mode::AllConfigs, code::EinCode, @nospecialize(xsa),
106126
LT = OMEinsum.labeltype(code)
107127
bounding_contract(mode, NestedEinsum(NestedEinsum{DynamicEinCode{LT}}.(1:length(xsa)), code), xsa, ymask, xsb; size_info=size_info)
108128
end
109-
function bounding_contract(mode::AllConfigs, code::NestedEinsum, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
129+
function bounding_contract(mode::AllConfigs, code::Union{NestedEinsum,SlicedEinsum}, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
110130
size_dict = size_info===nothing ? Dict{OMEinsum.labeltype(code.eins),Int}() : copy(size_info)
111131
OMEinsum.get_size_dict!(code, xsa, size_dict)
112132
# compute intermediate tensors
@@ -125,7 +145,7 @@ function solution_ad(code::EinCode, @nospecialize(xsa), ymask; size_info=nothing
125145
solution_ad(NestedEinsum(NestedEinsum{DynamicEinCode{LT}}.(1:length(xsa)), code), xsa, ymask; size_info=size_info)
126146
end
127147

128-
function solution_ad(code::NestedEinsum, @nospecialize(xsa), ymask; size_info=nothing)
148+
function solution_ad(code::Union{NestedEinsum,SlicedEinsum}, @nospecialize(xsa), ymask; size_info=nothing)
129149
size_dict = size_info===nothing ? Dict{OMEinsum.labeltype(code.eins),Int}() : copy(size_info)
130150
OMEinsum.get_size_dict!(code, xsa, size_dict)
131151
# compute intermediate tensors
@@ -138,6 +158,10 @@ function solution_ad(code::NestedEinsum, @nospecialize(xsa), ymask; size_info=no
138158
n, read_config!(code, mt, Dict())
139159
end
140160

161+
# get the solution configuration from gradients.
162+
function read_config!(code::SlicedEinsum, mt, out)
163+
read_config!(code.eins, mt, out)
164+
end
141165
function read_config!(code::NestedEinsum, mt, out)
142166
for (arg, ix, sibling) in zip(code.args, getixs(code.eins), mt.siblings)
143167
if OMEinsum.isleaf(arg)

src/configurations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function best_solutions(gp::GraphProblem; all=false, usecuda=false)
1717
T = (all ? set_type : sampler_type)(CountingTropical{Int64}, length(syms), bondsize(gp))
1818
vertex_index = Dict([s=>i for (i, s) in enumerate(syms)])
1919
xst = generate_tensors(l->TropicalF64(1.0), gp)
20-
ymask = trues(fill(2, length(_getiy(gp.code)))...)
20+
ymask = trues(fill(2, length(getiyv(gp.code)))...)
2121
if usecuda
2222
xst = CuArray.(xst)
2323
ymask = CuArray(ymask)
@@ -63,7 +63,7 @@ function bestk_solutions(gp::GraphProblem, k::Int)
6363
syms = symbols(gp)
6464
vertex_index = Dict([s=>i for (i, s) in enumerate(syms)])
6565
xst = generate_tensors(l->TropicalF64(1.0), gp)
66-
ymask = trues(fill(2, length(_getiy(gp.code)))...)
66+
ymask = trues(fill(2, length(getiyv(gp.code)))...)
6767
T = set_type(TruncatedPoly{k,Float64,Float64}, length(syms), bondsize(gp))
6868
xs = generate_tensors(l->onehotv(T, vertex_index[l], 1), gp)
6969
return bounding_contract(AllConfigs{k}(), gp.code, xst, ymask, xs)

src/graph_polynomials.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,11 @@ function _polynomial_single(gp::GraphProblem, ::Type{T}; usecuda, maxorder) wher
6262
return res
6363
end
6464

65-
_getiy(code::EinCode) = getiy(code)
66-
_getiy(code::NestedEinsum) = getiy(code.eins)
6765
function graph_polynomial(gp::GraphProblem, ::Val{:finitefield}; usecuda=false,
6866
maxorder=max_size(gp; usecuda=usecuda), max_iter=100)
6967
TI = Int32 # Int 32 is faster
7068
N = typemax(TI)
71-
YS = fill(Any[], (fill(bondsize(gp), length(_getiy(gp.code)))...,))
69+
YS = fill(Any[], (fill(bondsize(gp), length(getiyv(gp.code)))...,))
7270
local res, respre
7371
for k = 1:max_iter
7472
N = prevprime(N-TI(1))

src/networks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
export Independence, MaximalIndependence, Matching, Coloring, optimize_code, set_packing, MaxCut
2-
const EinTypes = Union{EinCode,NestedEinsum}
2+
const EinTypes = Union{EinCode,NestedEinsum,SlicedEinsum}
33

44
abstract type GraphProblem end
55

test/bounding.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using GraphTensorNetworks: cached_einsum, generate_masktree, masked_einsum, Cach
77
size_dict = uniformsize(code, 2)
88
c = cached_einsum(code, xs, size_dict)
99
@test c.content == code(xs...)
10-
mt = generate_masktree(code, c, rand(Bool,2,2,2), size_dict)
10+
mt = generate_masktree(AllConfigs{1}(), code, c, rand(Bool,2,2,2), size_dict)
1111
@test mt isa CacheTree{Bool}
1212
y = masked_einsum(code, xs, mt, size_dict)
1313
@test y isa AbstractArray
@@ -19,7 +19,7 @@ end
1919
xs = map(x->TropicalF64.(x), [rand(1:5,2,2), rand(1:5,2), rand(1:5,2,2), rand(1:5,2,2), rand(1:5,2,2)])
2020
code = ein"((ij,j),jk, kl), ii->kli"
2121
y1 = code(xs...)
22-
y2 = bounding_contract(code, xs, BitArray(ones(Bool,2,2,2)), xs)
22+
y2 = bounding_contract(AllConfigs{1}(), code, xs, BitArray(ones(Bool,2,2,2)), xs)
2323
@test y1 y2
2424
end
2525
rawcode = Independence(random_regular_graph(10, 3); optimizer=nothing).code
@@ -28,10 +28,9 @@ end
2828
length(ix)==1 ? GraphTensorNetworks.misv(TropicalF64(1.0)) : GraphTensorNetworks.misb(TropicalF64)
2929
end
3030
y1 = rawcode(xs...)
31-
y2 = bounding_contract(rawcode, xs, BitArray(fill(true)), xs)
31+
y2 = bounding_contract(AllConfigs{1}(), rawcode, xs, BitArray(fill(true)), xs)
3232
@test y1 y2
3333
y1 = optcode(xs...)
34-
y2 = bounding_contract(optcode, xs, BitArray(fill(true)), xs)
34+
y2 = bounding_contract(AllConfigs{1}(), optcode, xs, BitArray(fill(true)), xs)
3535
@test y1 y2
3636
end
37-

test/interfaces.jl

Lines changed: 82 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,43 +3,45 @@ using Graphs, Test
33

44
@testset "independence problem" begin
55
g = Graphs.smallgraph("petersen")
6-
gp = Independence(g; optimizer=GreedyMethod())
7-
res1 = solve(gp, "size max")[]
8-
res2 = solve(gp, "counting sum")[]
9-
res3 = solve(gp, "counting max")[]
10-
res4 = solve(gp, "counting max2")[]
11-
res5 = solve(gp, "counting all")[]
12-
res6 = solve(gp, "config max")[]
13-
res7 = solve(gp, "configs max")[]
14-
res8 = solve(gp, "configs max2")[]
15-
res9 = solve(gp, "configs all")[]
16-
res10 = solve(gp, "counting all (fft)")[]
17-
res11 = solve(gp, "counting all (finitefield)")[]
18-
res12 = solve(gp, "config max (bounded)")[]
19-
res13 = solve(gp, "configs max (bounded)")[]
20-
res14 = solve(gp, "counting max3")[]
21-
res15 = solve(gp, "configs max3")[]
22-
res16 = solve(gp, "configs max2 (bounded)")[]
23-
res17 = solve(gp, "configs max3 (bounded)")[]
24-
@test res1.n == 4
25-
@test res2 == 76
26-
@test res3.n == 4 && res3.c == 5
27-
@test res4.maxorder == 4 && res4.coeffs[1] == 30 && res4.coeffs[2]==5
28-
@test res5 == Polynomial([1.0, 10.0, 30, 30, 5])
29-
@test res6.c.data res7.c.data
30-
@test all(x->sum(x) == 4, res7.c.data)
31-
@test all(x->sum(x) == 3, res8.coeffs[1].data) && all(x->sum(x) == 4, res8.coeffs[2].data) && length(res8.coeffs[1].data) == 30 && length(res8.coeffs[2].data) == 5
32-
@test all(x->all(c->sum(c) == x[1]-1, x[2].data), enumerate(res9.coeffs))
33-
@test res10 res5
34-
@test res11 == res5
35-
@test res12.c.data res13.c.data
36-
@test res13.c.data == res7.c.data
37-
@test res14.maxorder == 4 && res14.coeffs[1]==30 && res14.coeffs[2] == 30 && res14.coeffs[3]==5
38-
@test all(x->sum(x) == 2, res15.coeffs[1].data) && all(x->sum(x) == 3, res15.coeffs[2].data) && all(x->sum(x) == 4, res15.coeffs[3].data) &&
39-
length(res15.coeffs[1].data) == 30 && length(res15.coeffs[2].data) == 30 && length(res15.coeffs[3].data) == 5
40-
@test all(x->sum(x) == 3, res16.coeffs[1].data) && all(x->sum(x) == 4, res16.coeffs[2].data) && length(res16.coeffs[1].data) == 30 && length(res16.coeffs[2].data) == 5
41-
@test all(x->sum(x) == 2, res17.coeffs[1].data) && all(x->sum(x) == 3, res17.coeffs[2].data) && all(x->sum(x) == 4, res17.coeffs[3].data) &&
42-
length(res17.coeffs[1].data) == 30 && length(res17.coeffs[2].data) == 30 && length(res17.coeffs[3].data) == 5
6+
for optimizer in (GreedyMethod(), TreeSA(ntrials=1))
7+
gp = Independence(g; optimizer=optimizer)
8+
res1 = solve(gp, "size max")[]
9+
res2 = solve(gp, "counting sum")[]
10+
res3 = solve(gp, "counting max")[]
11+
res4 = solve(gp, "counting max2")[]
12+
res5 = solve(gp, "counting all")[]
13+
res6 = solve(gp, "config max")[]
14+
res7 = solve(gp, "configs max")[]
15+
res8 = solve(gp, "configs max2")[]
16+
res9 = solve(gp, "configs all")[]
17+
res10 = solve(gp, "counting all (fft)")[]
18+
res11 = solve(gp, "counting all (finitefield)")[]
19+
res12 = solve(gp, "config max (bounded)")[]
20+
res13 = solve(gp, "configs max (bounded)")[]
21+
res14 = solve(gp, "counting max3")[]
22+
res15 = solve(gp, "configs max3")[]
23+
res16 = solve(gp, "configs max2 (bounded)")[]
24+
res17 = solve(gp, "configs max3 (bounded)")[]
25+
@test res1.n == 4
26+
@test res2 == 76
27+
@test res3.n == 4 && res3.c == 5
28+
@test res4.maxorder == 4 && res4.coeffs[1] == 30 && res4.coeffs[2]==5
29+
@test res5 == Polynomial([1.0, 10.0, 30, 30, 5])
30+
@test res6.c.data res7.c.data
31+
@test all(x->sum(x) == 4, res7.c.data)
32+
@test all(x->sum(x) == 3, res8.coeffs[1].data) && all(x->sum(x) == 4, res8.coeffs[2].data) && length(res8.coeffs[1].data) == 30 && length(res8.coeffs[2].data) == 5
33+
@test all(x->all(c->sum(c) == x[1]-1, x[2].data), enumerate(res9.coeffs))
34+
@test res10 res5
35+
@test res11 == res5
36+
@test res12.c.data res13.c.data
37+
@test res13.c == res7.c
38+
@test res14.maxorder == 4 && res14.coeffs[1]==30 && res14.coeffs[2] == 30 && res14.coeffs[3]==5
39+
@test all(x->sum(x) == 2, res15.coeffs[1].data) && all(x->sum(x) == 3, res15.coeffs[2].data) && all(x->sum(x) == 4, res15.coeffs[3].data) &&
40+
length(res15.coeffs[1].data) == 30 && length(res15.coeffs[2].data) == 30 && length(res15.coeffs[3].data) == 5
41+
@test all(x->sum(x) == 3, res16.coeffs[1].data) && all(x->sum(x) == 4, res16.coeffs[2].data) && length(res16.coeffs[1].data) == 30 && length(res16.coeffs[2].data) == 5
42+
@test all(x->sum(x) == 2, res17.coeffs[1].data) && all(x->sum(x) == 3, res17.coeffs[2].data) && all(x->sum(x) == 4, res17.coeffs[3].data) &&
43+
length(res17.coeffs[1].data) == 30 && length(res17.coeffs[2].data) == 30 && length(res17.coeffs[3].data) == 5
44+
end
4345
end
4446

4547
@testset "save load" begin
@@ -79,4 +81,46 @@ end
7981
save_configs("_test.txt", m; format=:text)
8082
mb = load_configs("_test.txt"; format=:text, nflavors=3)
8183
@test mb == m
82-
end
84+
end
85+
86+
@testset "slicing" begin
87+
g = Graphs.smallgraph("petersen")
88+
gp = Independence(g; optimizer=TreeSA(nslices=5, ntrials=1))
89+
res1 = solve(gp, "size max")[]
90+
res2 = solve(gp, "counting sum")[]
91+
res3 = solve(gp, "counting max")[]
92+
res4 = solve(gp, "counting max2")[]
93+
res5 = solve(gp, "counting all")[]
94+
res6 = solve(gp, "config max")[]
95+
res7 = solve(gp, "configs max")[]
96+
res8 = solve(gp, "configs max2")[]
97+
res9 = solve(gp, "configs all")[]
98+
res10 = solve(gp, "counting all (fft)")[]
99+
res11 = solve(gp, "counting all (finitefield)")[]
100+
res12 = solve(gp, "config max (bounded)")[]
101+
res13 = solve(gp, "configs max (bounded)")[]
102+
res14 = solve(gp, "counting max3")[]
103+
res15 = solve(gp, "configs max3")[]
104+
res16 = solve(gp, "configs max2 (bounded)")[]
105+
res17 = solve(gp, "configs max3 (bounded)")[]
106+
@test res1.n == 4
107+
@test res2 == 76
108+
@test res3.n == 4 && res3.c == 5
109+
@test res4.maxorder == 4 && res4.coeffs[1] == 30 && res4.coeffs[2]==5
110+
@test res5 == Polynomial([1.0, 10.0, 30, 30, 5])
111+
@test res6.c.data res7.c.data
112+
@test all(x->sum(x) == 4, res7.c.data)
113+
@test all(x->sum(x) == 3, res8.coeffs[1].data) && all(x->sum(x) == 4, res8.coeffs[2].data) && length(res8.coeffs[1].data) == 30 && length(res8.coeffs[2].data) == 5
114+
@test all(x->all(c->sum(c) == x[1]-1, x[2].data), enumerate(res9.coeffs))
115+
@test res10 res5
116+
@test res11 == res5
117+
@test res12.c.data res13.c.data
118+
@test res13.c == res7.c
119+
@test res14.maxorder == 4 && res14.coeffs[1]==30 && res14.coeffs[2] == 30 && res14.coeffs[3]==5
120+
@test all(x->sum(x) == 2, res15.coeffs[1].data) && all(x->sum(x) == 3, res15.coeffs[2].data) && all(x->sum(x) == 4, res15.coeffs[3].data) &&
121+
length(res15.coeffs[1].data) == 30 && length(res15.coeffs[2].data) == 30 && length(res15.coeffs[3].data) == 5
122+
@test all(x->sum(x) == 3, res16.coeffs[1].data) && all(x->sum(x) == 4, res16.coeffs[2].data) && length(res16.coeffs[1].data) == 30 && length(res16.coeffs[2].data) == 5
123+
@test all(x->sum(x) == 2, res17.coeffs[1].data) && all(x->sum(x) == 3, res17.coeffs[2].data) && all(x->sum(x) == 4, res17.coeffs[3].data) &&
124+
length(res17.coeffs[1].data) == 30 && length(res17.coeffs[2].data) == 30 && length(res17.coeffs[3].data) == 5
125+
end
126+

0 commit comments

Comments
 (0)