Skip to content

Commit 7e0cdab

Browse files
committed
coloring configurations
1 parent 82b7439 commit 7e0cdab

File tree

7 files changed

+169
-72
lines changed

7 files changed

+169
-72
lines changed

src/arithematics.jl

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
export is_commutative_semiring
2-
export Max2Poly, Polynomial, Tropical, CountingTropical, StaticBitVector, Mod, ConfigEnumerator, onehotv, ConfigSampler
3-
export bitstringset_type, bitstringsampler_type
2+
export Max2Poly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod, ConfigEnumerator, onehotv, ConfigSampler
3+
export set_type, sampler_type
44

55
using Polynomials: Polynomial
66
using TropicalNumbers: Tropical, CountingTropical
@@ -102,56 +102,56 @@ function Base.show(io::IO, ::MIME"text/plain", x::Max2Poly)
102102
end
103103
end
104104

105-
struct ConfigEnumerator{N,C}
106-
data::Vector{StaticBitVector{N,C}}
105+
struct ConfigEnumerator{N,S,C}
106+
data::Vector{StaticElementVector{N,S,C}}
107107
end
108108

109109
Base.length(x::ConfigEnumerator{N}) where N = length(x.data)
110-
Base.:(==)(x::ConfigEnumerator{N,C}, y::ConfigEnumerator{N,C}) where {N,C} = x.data == y.data
110+
Base.:(==)(x::ConfigEnumerator{N,S,C}, y::ConfigEnumerator{N,S,C}) where {N,S,C} = x.data == y.data
111111

112-
function Base.:+(x::ConfigEnumerator{N,C}, y::ConfigEnumerator{N,C}) where {N,C}
112+
function Base.:+(x::ConfigEnumerator{N,S,C}, y::ConfigEnumerator{N,S,C}) where {N,S,C}
113113
length(x) == 0 && return y
114114
length(y) == 0 && return x
115-
return ConfigEnumerator{N,C}(vcat(x.data, y.data))
115+
return ConfigEnumerator{N,S,C}(vcat(x.data, y.data))
116116
end
117117

118-
function Base.:*(x::ConfigEnumerator{L,C}, y::ConfigEnumerator{L,C}) where {L,C}
118+
function Base.:*(x::ConfigEnumerator{L,S,C}, y::ConfigEnumerator{L,S,C}) where {L,S,C}
119119
M, N = length(x), length(y)
120120
M == 0 && return x
121121
N == 0 && return y
122-
z = Vector{StaticBitVector{L,C}}(undef, M*N)
122+
z = Vector{StaticElementVector{L,S,C}}(undef, M*N)
123123
@inbounds for j=1:N, i=1:M
124124
z[(j-1)*M+i] = x.data[i] | y.data[j]
125125
end
126-
return ConfigEnumerator{L,C}(z)
126+
return ConfigEnumerator{L,S,C}(z)
127127
end
128128

129-
Base.zero(::Type{ConfigEnumerator{N,C}}) where {N,C} = ConfigEnumerator{N,C}(StaticBitVector{N,C}[])
130-
Base.one(::Type{ConfigEnumerator{N,C}}) where {N,C} = ConfigEnumerator{N,C}([staticfalses(StaticBitVector{N,C})])
131-
Base.zero(::ConfigEnumerator{N,C}) where {N,C} = zero(ConfigEnumerator{N,C})
132-
Base.one(::ConfigEnumerator{N,C}) where {N,C} = one(ConfigEnumerator{N,C})
129+
Base.zero(::Type{ConfigEnumerator{N,S,C}}) where {N,S,C} = ConfigEnumerator{N,S,C}(StaticElementVector{N,S,C}[])
130+
Base.one(::Type{ConfigEnumerator{N,S,C}}) where {N,S,C} = ConfigEnumerator{N,S,C}([zero(StaticElementVector{N,S,C})])
131+
Base.zero(::ConfigEnumerator{N,S,C}) where {N,S,C} = zero(ConfigEnumerator{N,S,C})
132+
Base.one(::ConfigEnumerator{N,S,C}) where {N,S,C} = one(ConfigEnumerator{N,S,C})
133133
Base.show(io::IO, x::ConfigEnumerator) = print(io, "{", join(x.data, ", "), "}")
134134
Base.show(io::IO, ::MIME"text/plain", x::ConfigEnumerator) = Base.show(io, x)
135135

136136
# the algebra sampling one of the configurations
137-
struct ConfigSampler{N,C}
138-
data::StaticBitVector{N,C}
137+
struct ConfigSampler{N,S,C}
138+
data::StaticElementVector{N,S,C}
139139
end
140140

141-
Base.:(==)(x::ConfigSampler{N,C}, y::ConfigSampler{N,C}) where {N,C} = x.data == y.data
141+
Base.:(==)(x::ConfigSampler{N,S,C}, y::ConfigSampler{N,S,C}) where {N,S,C} = x.data == y.data
142142

143-
function Base.:+(x::ConfigSampler{N,C}, y::ConfigSampler{N,C}) where {N,C} # biased sampling: return `x`, maybe using random sampler is better.
143+
function Base.:+(x::ConfigSampler{N,S,C}, y::ConfigSampler{N,S,C}) where {N,S,C} # biased sampling: return `x`, maybe using random sampler is better.
144144
return x
145145
end
146146

147-
function Base.:*(x::ConfigSampler{L,C}, y::ConfigSampler{L,C}) where {L,C}
147+
function Base.:*(x::ConfigSampler{L,S,C}, y::ConfigSampler{L,S,C}) where {L,S,C}
148148
ConfigSampler(x.data | y.data)
149149
end
150150

151-
Base.zero(::Type{ConfigSampler{N,C}}) where {N,C} = ConfigSampler{N,C}(statictrues(StaticBitVector{N,C}))
152-
Base.one(::Type{ConfigSampler{N,C}}) where {N,C} = ConfigSampler{N,C}(staticfalses(StaticBitVector{N,C}))
153-
Base.zero(::ConfigSampler{N,C}) where {N,C} = zero(ConfigSampler{N,C})
154-
Base.one(::ConfigSampler{N,C}) where {N,C} = one(ConfigSampler{N,C})
151+
Base.zero(::Type{ConfigSampler{N,S,C}}) where {N,S,C} = ConfigSampler{N,S,C}(statictrues(StaticElementVector{N,S,C}))
152+
Base.one(::Type{ConfigSampler{N,S,C}}) where {N,S,C} = ConfigSampler{N,S,C}(staticfalses(StaticElementVector{N,S,C}))
153+
Base.zero(::ConfigSampler{N,S,C}) where {N,S,C} = zero(ConfigSampler{N,S,C})
154+
Base.one(::ConfigSampler{N,S,C}) where {N,S,C} = one(ConfigSampler{N,S,C})
155155

156156
# A patch to make `Polynomial{ConfigEnumerator}` work
157157
function Base.:*(a::Int, y::ConfigEnumerator)
@@ -166,33 +166,34 @@ function Base.:*(a::Int, y::ConfigSampler)
166166
end
167167

168168
# convert from counting type to bitstring type
169-
for (F,TP) in [(:bitstringset_type, :ConfigEnumerator), (:bitstringsampler_type, :ConfigSampler)]
169+
for (F,TP) in [(:set_type, :ConfigEnumerator), (:sampler_type, :ConfigSampler)]
170170
@eval begin
171-
function $F(::Type{T}, n::Int) where {OT, T<:Max2Poly{C,OT} where C}
172-
Max2Poly{$F(n),OT}
171+
function $F(::Type{T}, n::Int, nflavor::Int) where {OT, T<:Max2Poly{C,OT} where C}
172+
Max2Poly{$F(n,nflavor),OT}
173173
end
174-
function $F(::Type{T}, n::Int) where {TX, T<:Polynomial{C,TX} where C}
175-
Polynomial{$F(n),:x}
174+
function $F(::Type{T}, n::Int, nflavor::Int) where {TX, T<:Polynomial{C,TX} where C}
175+
Polynomial{$F(n,nflavor),:x}
176176
end
177-
function $F(::Type{T}, n::Int) where {TV, T<:CountingTropical{TV}}
178-
CountingTropical{TV, $F(n)}
177+
function $F(::Type{T}, n::Int, nflavor::Int) where {TV, T<:CountingTropical{TV}}
178+
CountingTropical{TV, $F(n,nflavor)}
179179
end
180-
function $F(n::Integer)
181-
C = _nints(n)
182-
return $TP{n, C}
180+
function $F(n::Integer, nflavor::Integer)
181+
s = ceil(Int, log2(nflavor))
182+
c = _nints(n,s)
183+
return $TP{n,s,c}
183184
end
184185
end
185186
end
186187

187188
# utilities for creating onehot vectors
188-
function onehotv(::Type{Polynomial{BS,X}}, x) where {BS,X}
189-
Polynomial{BS,X}([zero(BS), onehotv(BS, x)])
189+
function onehotv(::Type{Polynomial{BS,X}}, x, v) where {BS,X}
190+
Polynomial{BS,X}([zero(BS), onehotv(BS, x, v)])
190191
end
191-
function onehotv(::Type{Max2Poly{BS,OS}}, x) where {BS,OS}
192-
Max2Poly{BS,OS}(zero(BS), onehotv(BS, x),one(OS))
192+
function onehotv(::Type{Max2Poly{BS,OS}}, x, v) where {BS,OS}
193+
Max2Poly{BS,OS}(zero(BS), onehotv(BS, x, v),one(OS))
193194
end
194-
function onehotv(::Type{CountingTropical{TV,BS}}, x) where {TV,BS}
195-
CountingTropical{TV,BS}(one(TV), onehotv(BS, x))
195+
function onehotv(::Type{CountingTropical{TV,BS}}, x, v) where {TV,BS}
196+
CountingTropical{TV,BS}(one(TV), onehotv(BS, x, v))
196197
end
197-
onehotv(::Type{ConfigEnumerator{N,C}}, i::Integer) where {N,C} = ConfigEnumerator([onehotv(StaticBitVector{N,C}, i)])
198-
onehotv(::Type{ConfigSampler{N,C}}, i::Integer) where {N,C} = ConfigSampler(onehotv(StaticBitVector{N,C}, i))
198+
onehotv(::Type{ConfigEnumerator{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigEnumerator([onehotv(StaticElementVector{N,S,C}, i, v)])
199+
onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampler(onehotv(StaticElementVector{N,S,C}, i, v))

src/bitvector.jl

Lines changed: 76 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,104 @@
11
# StaticBitVector
2-
export StaticBitVector
2+
export StaticBitVector, StaticElementVector
33

4-
struct StaticBitVector{N,C}
4+
"""
5+
StaticElementVector{N,S,C}
6+
7+
`N` is the length of vector, `C` is the size of storage in unit of `UInt64`, `S` is the stride defined as the `log2(# of flavors)`.
8+
"""
9+
struct StaticElementVector{N,S,C}
510
data::NTuple{C,UInt64}
611
end
7-
function StaticBitVector(x::AbstractVector)
12+
13+
Base.length(::StaticElementVector{N,S,C}) where {N,S,C} = N
14+
Base.:(==)(x::StaticElementVector, y::AbstractVector) = [x...] == [y...]
15+
Base.:(==)(x::AbstractVector, y::StaticElementVector) = [x...] == [y...]
16+
Base.:(==)(x::StaticElementVector{N,S,C}, y::StaticElementVector{N,S,C}) where {N,S,C} = x.data == y.data
17+
@inline function Base.getindex(x::StaticElementVector{N,S,C}, i::Integer) where {N,S,C}
18+
@boundscheck i <= N || throw(BoundsError(x, i))
19+
i1 = (i-1)*S+1 # start point
20+
i2 = i*S # stop point
21+
ii1 = (i1-1) ÷ 64
22+
ii2 = (i2-1) ÷ 64
23+
@inbounds if ii1 == ii2
24+
(x.data[ii1+1] >> (i1-ii1*64-1)) & (1<<S - 1)
25+
else # cross two integers
26+
(x.data[ii1+1] >> (i1-ii*64-S+1)) | (x.data[ii2+1] & (1<<(i2-ii1*64) - 1))
27+
end
28+
end
29+
function StaticElementVector(nflavor::Int, x::AbstractVector)
830
N = length(x)
9-
StaticBitVector{N,_nints(N)}((convert(BitVector, x).chunks...,))
31+
S = ceil(Int,log2(nflavor)) # sometimes can not devide 64.
32+
convert(StaticElementVector{N,S,_nints(N,S)}, x)
1033
end
11-
function Base.convert(::Type{StaticBitVector{N,C}}, x::AbstractVector) where {N,C}
34+
function Base.convert(::Type{StaticElementVector{N,S,C}}, x::AbstractVector) where {N,S,C}
1235
@assert length(x) == N
13-
StaticBitVector(x)
36+
data = zeros(UInt64,C)
37+
for i=1:N
38+
i1 = (i-1)*S+1 # start point
39+
i2 = i*S # stop point
40+
ii1 = (i1-1) ÷ 64
41+
ii2 = (i2-1) ÷ 64
42+
@inbounds if ii1 == ii2
43+
data[ii1+1] |= UInt64(x[i]) << (i1-ii1*64-1)
44+
else # cross two integers
45+
data[ii1+1] |= UInt64(x[i]) << (i1-ii1*64-1)
46+
data[ii2+1] |= UInt64(x[i]) >> (i2-ii1*64)
47+
end
48+
end
49+
return StaticElementVector{N,S,C}((data...,))
50+
end
51+
# joining two element sets
52+
Base.:(|)(x::StaticElementVector{N,S,C}, y::StaticElementVector{N,S,C}) where {N,S,C} = StaticElementVector{N,S,C}(x.data .| y.data)
53+
# intersection of two element sets
54+
Base.:(&)(x::StaticElementVector{N,S,C}, y::StaticElementVector{N,S,C}) where {N,S,C} = StaticElementVector{N,S,C}(x.data .& y.data)
55+
# difference of two element sets
56+
Base.:()(x::StaticElementVector{N,S,C}, y::StaticElementVector{N,S,C}) where {N,S,C} = StaticElementVector{N,S,C}(x.data .⊻ y.data)
57+
58+
function onehotv(::Type{StaticElementVector{N,S,C}}, i, v) where {N,S,C}
59+
x = zeros(Int,N)
60+
x[i] = v
61+
return convert(StaticElementVector{N,S,C}, x)
1462
end
15-
_nints(x) = (x-1)÷64+1
16-
Base.length(::StaticBitVector{N,C}) where {N,C} = N
17-
Base.:(==)(x::StaticBitVector, y::AbstractVector) = [x...] == [y...]
18-
Base.:(==)(x::AbstractVector, y::StaticBitVector) = [x...] == [y...]
19-
Base.:(==)(x::StaticBitVector, y::StaticBitVector) = [x...] == [y...]
20-
function Base.getindex(x::StaticBitVector{N,C}, i::Integer) where {N,C}
63+
64+
##### BitVectors
65+
const StaticBitVector{N,C} = StaticElementVector{N,1,C}
66+
@inline function Base.getindex(x::StaticBitVector{N,C}, i::Integer) where {N,C}
67+
@boundscheck i <= N || throw(BoundsError(x, i)) # TODO: make this @boundscheck work.
2168
i -= 1
2269
ii = i ÷ 64
2370
(x.data[ii+1] >> (i-ii*64)) & 1
2471
end
25-
Base.:(|)(x::StaticBitVector{N,C}, y::StaticBitVector{N,C}) where {N,C} = StaticBitVector{N,C}(x.data .| y.data)
26-
Base.:(&)(x::StaticBitVector{N,C}, y::StaticBitVector{N,C}) where {N,C} = StaticBitVector{N,C}(x.data .& y.data)
27-
Base.:()(x::StaticBitVector{N,C}, y::StaticBitVector{N,C}) where {N,C} = StaticBitVector{N,C}(x.data .⊻ y.data)
28-
@generated function staticfalses(::Type{StaticBitVector{N,C}}) where {N,C}
29-
Expr(:call, :(StaticBitVector{$N,$C}), Expr(:tuple, zeros(UInt64, C)...))
72+
73+
function StaticBitVector(x::AbstractVector)
74+
N = length(x)
75+
StaticBitVector{N,_nints(N,1)}((convert(BitVector, x).chunks...,))
76+
end
77+
function Base.convert(::Type{StaticBitVector{N,C}}, x::AbstractVector) where {N,C}
78+
@assert length(x) == N
79+
StaticBitVector(x)
80+
end
81+
_nints(x,s) = (x*s-1)÷64+1
82+
83+
@generated function Base.zero(::Type{StaticElementVector{N,S,C}}) where {N,S,C}
84+
Expr(:call, :(StaticElementVector{$N,$S,$C}), Expr(:tuple, zeros(UInt64, C)...))
3085
end
86+
staticfalses(::Type{StaticBitVector{N,C}}) where {N,C} = zero(StaticBitVector{N,C})
3187
@generated function statictrues(::Type{StaticBitVector{N,C}}) where {N,C}
3288
Expr(:call, :(StaticBitVector{$N,$C}), Expr(:tuple, fill(typemax(UInt64), C)...))
3389
end
90+
onehotv(::Type{StaticBitVector{N,C}}, i, v) where {N,C} = v > 0 ? onehotv(StaticBitVector{N,C}, i) : zero(StaticBitVector{N,C})
3491
function onehotv(::Type{StaticBitVector{N,C}}, i) where {N,C}
3592
x = falses(N)
3693
x[i] = true
3794
return StaticBitVector(x)
3895
end
39-
function Base.iterate(x::StaticBitVector{N,C}, state=1) where {N,C}
96+
function Base.iterate(x::StaticElementVector{N,S,C}, state=1) where {N,S,C}
4097
if state > N
4198
return nothing
4299
else
43100
return x[state], state+1
44101
end
45102
end
46103

47-
Base.show(io::IO, t::StaticBitVector) = Base.print(io, "$(join(Int.(t), ""))")
104+
Base.show(io::IO, t::StaticElementVector) = Base.print(io, "$(join(Int.(t), ""))")

src/configurations.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function optimalsolutions(gp::GraphProblem; all=false, usecuda=false)
1313
throw(ArgumentError("ConfigEnumerator can not be computed on GPU!"))
1414
end
1515
syms = symbols(gp)
16-
T = (all ? bitstringset_type : bitstringsampler_type)(CountingTropical{Int64}, length(syms))
16+
T = (all ? set_type : sampler_type)(CountingTropical{Int64}, length(syms), bondsize(gp))
1717
vertex_index = Dict([s=>i for (i, s) in enumerate(syms)])
1818
xst = generate_tensors(l->TropicalF64(1.0), gp)
1919
ymask = trues(fill(2, length(OMEinsum.getiy(flatten(gp.code))))...)
@@ -22,7 +22,7 @@ function optimalsolutions(gp::GraphProblem; all=false, usecuda=false)
2222
ymask = CuArray(ymask)
2323
end
2424
if all
25-
xs = generate_tensors(l->onehotv(T, vertex_index[l]), gp)
25+
xs = generate_tensors(l->onehotv(T, vertex_index[l], 1), gp)
2626
return bounding_contract(gp.code, xst, ymask, xs)
2727
else
2828
@assert ndims(ymask) == 0
@@ -52,13 +52,26 @@ function solutions(gp::GraphProblem, ::Type{BT}; all=false, usecuda=false) where
5252
end
5353

5454
# return a mapping from label to variable `x`
55-
function fx_solutions(gp::GraphProblem, ::Type{BT}, all::Bool) where BT
55+
for GP in [:Independence, :Matching, :MaximalIndependence, :MaxCut]
56+
@eval function fx_solutions(gp::$GP, ::Type{BT}, all::Bool) where BT
57+
syms = symbols(gp)
58+
T = (all ? set_type : sampler_type)(BT, length(syms), bondsize(gp))
59+
vertex_index = Dict([s=>i for (i, s) in enumerate(syms)])
60+
return l->onehotv(T, vertex_index[l], 1)
61+
end
62+
end
63+
function fx_solutions(gp::Coloring{K}, ::Type{BT}, all::Bool) where {K,BT}
5664
syms = symbols(gp)
57-
T = (all ? bitstringset_type : bitstringsampler_type)(BT, length(syms))
65+
T = (all ? set_type : sampler_type)(BT, length(syms), bondsize(gp))
5866
vertex_index = Dict([s=>i for (i, s) in enumerate(syms)])
59-
return l->onehotv(T, vertex_index[l])
67+
return function (l)
68+
map(1:K) do k
69+
onehotv(T, vertex_index[l], k)
70+
end
71+
end
6072
end
61-
for GP in [:Independence, :Matching, :MaximalIndependence]
73+
74+
for GP in [:Independence, :Matching, :MaximalIndependence, :Coloring]
6275
@eval symbols(gp::$GP) = labels(gp.code)
6376
end
6477
symbols(gp::MaxCut) = collect(OMEinsum.getixs(OMEinsum.flatten(gp.code)))

src/graph_polynomials.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ function generate_tensors(fx, c::Coloring{K}) where K
134134
return map(ixs) do ix
135135
# if the tensor rank is 1, create a vertex tensor.
136136
# otherwise the tensor rank must be 2, create a bond tensor.
137-
length(ix)==1 ? coloringv(f(ix[1])) : coloringb(T, K)
137+
length(ix)==1 ? coloringv(fx(ix[1])) : coloringb(T, K)
138138
end
139139
end
140140

src/networks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ struct Coloring{K,CT<:EinTypes} <: GraphProblem
7575
end
7676
Coloring{K}(code::ET) where {K,ET<:EinTypes} = Coloring{K,ET}(code)
7777
# same network layout as independent set.
78-
Coloring(g::SimpleGraph; outputs=(), kwargs...) = Coloring(Independence(g; outputs=outputs, kwargs...).code)
78+
Coloring{K}(g::SimpleGraph; outputs=(), kwargs...) where K = Coloring{K}(Independence(g; outputs=outputs, kwargs...).code)
7979

8080
"""
8181
labels(code)

test/bitvector.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using GraphTensorNetworks: statictrues, staticfalses, StaticBitVector, onehotv
44
@testset "static bit vector" begin
55
@test statictrues(StaticBitVector{3,1}) == trues(3)
66
@test staticfalses(StaticBitVector{3,1}) == falses(3)
7+
@test_throws BoundsError statictrues(StaticBitVector{3,1})[4]
8+
#@test (@inbounds statictrues(StaticBitVector{3,1})[4]) == 0
79
x = rand(Bool, 131)
810
y = rand(Bool, 131)
911
a = StaticBitVector(x)
@@ -14,5 +16,12 @@ using GraphTensorNetworks: statictrues, staticfalses, StaticBitVector, onehotv
1416
@test op(a, b) == op.(a2, b2)
1517
end
1618
@test onehotv(StaticBitVector{133,3}, 5) == (x = falses(133); x[5]=true; x)
19+
@test [StaticElementVector(3, [3,1,0,1])...] == [3,1,0,1]
20+
bl = rand(1:3,100)
21+
@test [StaticElementVector(3, bl)...] == bl
22+
bl = rand(1:15,100)
23+
xl = StaticElementVector(16, bl)
24+
@test typeof(xl) == StaticElementVector{100,4,7}
25+
@test [xl...] == bl
1726
end
1827

0 commit comments

Comments
 (0)