Skip to content

Commit 75dc16c

Browse files
authored
configs for max-k configurations (#28)
* use tropical in extended tropical * done
1 parent ec8869f commit 75dc16c

File tree

7 files changed

+114
-65
lines changed

7 files changed

+114
-65
lines changed

examples/IndependentSet.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ independence_polynomial = solve(problem, GraphPolynomial(; method=:finitefield))
9292
# ### Configuration properties
9393
# ##### finding one maximum independent set (MIS)
9494
# One can use [`SingleConfigMax`](@ref) to find one of the solution with largest set size, and it has two implementations.
95-
# The unbounded (default) version uses [`ConfigSampler`](@ref) to sample one of the best solutions directly.
95+
# The unbounded (default) version uses a joint type of [`CountingTropical`](@ref) and [`ConfigSampler`](@ref) in computation,
96+
# where `CountingTropical` finds the maximum size and `ConfigSampler` samples one of the best solutions.
9697
# The bounded version uses the binary gradient back-propagation (see our paper) to compute the gradients.
9798
# It requires caching intermediate states, but is often faster (on CPU) because it can use [`TropicalGEMM`](https://github.com/TensorBFS/TropicalGEMM.jl) (see [Performance Tips](@ref)).
9899
max_config = solve(problem, SingleConfigMax(; bounded=false))[]
@@ -172,6 +173,18 @@ show_graph(graph; locs=locations, vertex_colors=
172173
spectrum = solve(problem, SizeMax(10))
173174

174175
# It uses the [`ExtendedTropical`](@ref) as the tensor elements.
176+
# One can get sets with maximum `K` sizes, by combining [`ExtendedTropical`](@ref) and the algebra in the previous section for sampling one configuration.
177+
max5_configs = solve(problem, SingleConfigMax(5))[]
178+
179+
imgs_max5 = ntuple(k->show_graph(graph;
180+
locs=locations, scale=0.25,
181+
vertex_colors=[iszero(max5_configs.orders[k].c.data[i]) ? "white" : "red"
182+
for i=1:nv(graph)]), 5);
183+
184+
Compose.set_default_graphic_size(18cm, 4cm)
185+
186+
Compose.compose(context(),
187+
ntuple(k->(context((k-1)/5, 0.0, 1.2/5, 1.0), imgs_max5[k]), 5)...)
175188

176189
# ## Open vertices and MIS tensor analysis
177190
# The following code computes the MIS tropical tensor (reference to be added) with open vertices 1, 2 and 3.

src/arithematics.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ end
175175
176176
Extended Tropical numbers with largest `K` orders keeped,
177177
or the [`TruncatedPoly`](@ref) without coefficients,
178-
`TO` is the element type of orders.
178+
`TO` is the element type of orders, usually [`Tropical`](@ref) numbers.
179179
This algebra maps
180180
181181
* `+` to finding largest `K` values of union of two sets.
@@ -186,23 +186,23 @@ This algebra maps
186186
Example
187187
------------------------------
188188
```jldoctest; setup=(using GraphTensorNetworks)
189-
julia> x = ExtendedTropical{3}([1.0, 2, 3])
190-
ExtendedTropical{3, Float64}([1.0, 2.0, 3.0])
189+
julia> x = ExtendedTropical{3}(Tropical.([1.0, 2, 3]))
190+
ExtendedTropical{3, TropicalF64}(TropicalF64[1.0ₜ, 2.0ₜ, 3.0ₜ])
191191
192-
julia> y = ExtendedTropical{3}([-Inf, 2, 5])
193-
ExtendedTropical{3, Float64}([-Inf, 2.0, 5.0])
192+
julia> y = ExtendedTropical{3}(Tropical.([-Inf, 2, 5]))
193+
ExtendedTropical{3, TropicalF64}(TropicalF64[-Infₜ, 2.0ₜ, 5.0ₜ])
194194
195195
julia> x * y
196-
ExtendedTropical{3, Float64}([6.0, 7.0, 8.0])
196+
ExtendedTropical{3, TropicalF64}(TropicalF64[6.0ₜ, 7.0ₜ, 8.0ₜ])
197197
198198
julia> x + y
199-
ExtendedTropical{3, Float64}([2.0, 3.0, 5.0])
199+
ExtendedTropical{3, TropicalF64}(TropicalF64[2.0ₜ, 3.0ₜ, 5.0ₜ])
200200
201201
julia> one(x)
202-
ExtendedTropical{3, Float64}([-Inf, -Inf, 0.0])
202+
ExtendedTropical{3, TropicalF64}(TropicalF64[-Infₜ, -Infₜ, 0.0ₜ])
203203
204204
julia> zero(x)
205-
ExtendedTropical{3, Float64}([-Inf, -Inf, -Inf])
205+
ExtendedTropical{3, TropicalF64}(TropicalF64[-Infₜ, -Infₜ, -Infₜ])
206206
```
207207
"""
208208
struct ExtendedTropical{K,TO} <: Number
@@ -223,10 +223,10 @@ end
223223
function sorted_sum_combination!(res::AbstractVector{TO}, A::AbstractVector{TO}, B::AbstractVector{TO}) where TO
224224
K = length(res)
225225
@assert length(B) == length(A) == K
226-
maxval = A[K] + B[K]
226+
maxval = A[K] * B[K]
227227
ptr = K
228228
res[ptr] = maxval
229-
queue = [(K,K-1,A[K]+B[K-1]), (K-1,K,A[K-1]+B[K])]
229+
queue = [(K,K-1,A[K]*B[K-1]), (K-1,K,A[K-1]*B[K])]
230230
for k = 1:K-1
231231
(i, j, res[K-k]) = _pop_max_sum!(queue) # TODO: do not enumerate, use better data structures
232232
_push_if_not_exists!(queue, i, j-1, A, B)
@@ -237,7 +237,7 @@ end
237237

238238
function _push_if_not_exists!(queue, i, j, A, B)
239239
@inbounds if j>=1 && i>=1 && !any(x->x[1] >= i && x[2] >= j, queue)
240-
push!(queue, (i, j, A[i] + B[j]))
240+
push!(queue, (i, j, A[i]*B[j]))
241241
end
242242
end
243243

@@ -281,14 +281,14 @@ end
281281
Base.:^(a::ExtendedTropical, b::Integer) = Base.invoke(^, Tuple{ExtendedTropical, Real}, a, b)
282282
function Base.:^(a::ExtendedTropical{K,TO}, b::Real) where {K,TO}
283283
if iszero(b) # to avoid NaN
284-
return one(ExtendedTropical{K,promote_type(TO,typeof(b))})
284+
return one(ExtendedTropical{K,TO})
285285
else
286-
return ExtendedTropical{K,TO}(a.orders .* b)
286+
return ExtendedTropical{K,TO}(a.orders .^ b)
287287
end
288288
end
289289

290-
Base.zero(::Type{ExtendedTropical{K,TO}}) where {K,TO} = ExtendedTropical{K,TO}(fill(zero(Tropical{TO}).n, K))
291-
Base.one(::Type{ExtendedTropical{K,TO}}) where {K,TO} = ExtendedTropical{K,TO}(map(i->i==K ? one(Tropical{TO}).n : zero(Tropical{TO}).n, 1:K))
290+
Base.zero(::Type{ExtendedTropical{K,TO}}) where {K,TO} = ExtendedTropical{K,TO}(fill(zero(TO), K))
291+
Base.one(::Type{ExtendedTropical{K,TO}}) where {K,TO} = ExtendedTropical{K,TO}(map(i->i==K ? one(TO) : zero(TO), 1:K))
292292
Base.zero(::ExtendedTropical{K,TO}) where {K,TO} = zero(ExtendedTropical{K,TO})
293293
Base.one(::ExtendedTropical{K,TO}) where {K,TO} = one(ExtendedTropical{K,TO})
294294

@@ -703,6 +703,7 @@ for (F,TP) in [(:set_type, :ConfigEnumerator), (:sampler_type, :ConfigSampler),
703703
end
704704
end
705705
end
706+
sampler_type(::Type{ExtendedTropical{K,T}}, n::Int, nflavor::Int) where {K,T} = ExtendedTropical{K, sampler_type(T, n, nflavor)}
706707

707708
# utilities for creating onehot vectors
708709
onehotv(::Type{ConfigEnumerator{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigEnumerator([onehotv(StaticElementVector{N,S,C}, i, v)])
@@ -774,8 +775,7 @@ function _x(::Type{Tropical{TV}}; invert) where {TV}
774775
invert ? pre_invert_exponent(ret) : ret
775776
end
776777
function _x(::Type{ExtendedTropical{K,TO}}; invert) where {K,TO}
777-
ret =ExtendedTropical{K,TO}(map(i->i==K ? one(TO) : zero(Tropical{TO}).n, 1:K))
778-
invert ? pre_invert_exponent(ret) : ret
778+
return ExtendedTropical{K,TO}(map(i->i==K ? _x(TO; invert=invert) : zero(TO), 1:K))
779779
end
780780

781781
# for finding all solutions
@@ -796,12 +796,14 @@ end
796796
function _onehotv(::Type{BS}, x, v) where {BS<:AbstractSetNumber}
797797
onehotv(BS, x, v)
798798
end
799+
function _onehotv(::Type{ExtendedTropical{K,TO}}, x, v) where {K,T,BS<:AbstractSetNumber,TO<:CountingTropical{T,BS}}
800+
ExtendedTropical{K,TO}(map(i->i==K ? _onehotv(TO, x, v) : zero(TO), 1:K))
801+
end
799802

800803
# negate the exponents before entering the solver
801804
pre_invert_exponent(t::TruncatedPoly{K}) where K = TruncatedPoly(t.coeffs, -t.maxorder)
802805
pre_invert_exponent(t::TropicalNumbers.TropicalTypes) = inv(t)
803-
pre_invert_exponent(t::ExtendedTropical{K}) where K = ExtendedTropical{K}(map(i->i==K ? -t.orders[i] : t.orders[i], 1:K))
804806
# negate the exponents after entering the solver
805807
post_invert_exponent(t::TruncatedPoly{K}) where K = TruncatedPoly(ntuple(i->t.coeffs[K-i+1], K), -t.maxorder+(K-1))
806808
post_invert_exponent(t::TropicalNumbers.TropicalTypes) = inv(t)
807-
post_invert_exponent(t::ExtendedTropical{K}) where K = ExtendedTropical{K}(map(i->-t.orders[i], K:-1:1))
809+
post_invert_exponent(t::ExtendedTropical{K}) where K = ExtendedTropical{K}(map(i->inv(t.orders[i]), K:-1:1))

src/interfaces.jl

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,10 @@ GraphPolynomial(; method::Symbol = :finitefield, kwargs...) = GraphPolynomial{me
109109
graph_polynomial_method(::GraphPolynomial{METHOD}) where METHOD = METHOD
110110

111111
"""
112-
SingleConfigMax{BOUNDED} <: AbstractProperty
113-
SingleConfigMax(; bounded=false)
112+
SingleConfigMax{K, BOUNDED} <: AbstractProperty
113+
SingleConfigMax(k::Int; bounded=false)
114114
115-
Finding single best solution, e.g. for [`IndependentSet`](@ref) problem, it is one of the maximum independent sets.
115+
Finding single solution for largest-K sizes, e.g. for [`IndependentSet`](@ref) problem, it is one of the maximum independent sets.
116116
117117
* The corresponding data type is [`CountingTropical{Float64,<:ConfigSampler}`](@ref) if `BOUNDED` is `false`, [`Tropical`](@ref) otherwise.
118118
* Weighted graph problems is supported.
@@ -122,14 +122,15 @@ Keyword Arguments
122122
----------------------------
123123
* `bounded`, if it is true, use bounding trick (or boolean gradients) to reduce the working memory to store intermediate configurations.
124124
"""
125-
struct SingleConfigMax{BOUNDED} <:AbstractProperty end
126-
SingleConfigMax(; bounded::Bool=false) = SingleConfigMax{bounded}()
125+
struct SingleConfigMax{K,BOUNDED} <:AbstractProperty end
126+
SingleConfigMax(k::Int=1; bounded::Bool=false) = SingleConfigMax{k, bounded}()
127+
max_k(::SingleConfigMax{K}) where K = K
127128

128129
"""
129-
SingleConfigMin{BOUNDED} <: AbstractProperty
130-
SingleConfigMin(; bounded=false)
130+
SingleConfigMin{K, BOUNDED} <: AbstractProperty
131+
SingleConfigMin(k::Int; bounded=false)
131132
132-
Finding single "worst" solution.
133+
Finding single solution with smallest-K size.
133134
134135
* The corresponding data type is inverted [`CountingTropical{Float64,<:ConfigSampler}`](@ref) if `BOUNDED` is `false`, inverted [`Tropical`](@ref) otherwise.
135136
* Weighted graph problems is supported.
@@ -139,8 +140,9 @@ Keyword Arguments
139140
----------------------------
140141
* `bounded`, if it is true, use bounding trick (or boolean gradients) to reduce the working memory to store intermediate configurations.
141142
"""
142-
struct SingleConfigMin{BOUNDED} <:AbstractProperty end
143-
SingleConfigMin(; bounded::Bool=false) = SingleConfigMin{bounded}()
143+
struct SingleConfigMin{K,BOUNDED} <:AbstractProperty end
144+
SingleConfigMin(k::Int=1; bounded::Bool=false) = SingleConfigMin{k,bounded}()
145+
min_k(::SingleConfigMin{K}) where K = K
144146

145147
"""
146148
ConfigsAll{TREESTORAGE} <:AbstractProperty
@@ -216,7 +218,7 @@ Positional Arguments
216218
* [`CountingAll`](@ref) for counting all configurations,
217219
* [`GraphPolynomial`](@ref) for evaluating the graph polynomial,
218220
219-
* [`SingleConfigMax`](@ref) for finding one maximum configuration,
221+
* [`SingleConfigMax`](@ref) for finding one maximum configuration for each size,
220222
* [`ConfigsMax`](@ref) for enumerating configurations with largest-K sizes,
221223
* [`ConfigsMin`](@ref) for enumerating configurations with smallest-K sizes,
222224
* [`ConfigsAll`](@ref) for enumerating all configurations,
@@ -236,9 +238,9 @@ function solve(gp::GraphProblem, property::AbstractProperty; T=Float64, usecuda=
236238
elseif property isa SizeMin{1}
237239
return post_invert_exponent.(contractx(gp, _x(Tropical{T}; invert=true); usecuda=usecuda))
238240
elseif property isa SizeMax
239-
return contractx(gp, _x(ExtendedTropical{max_k(property), T}; invert=false); usecuda=usecuda)
241+
return contractx(gp, _x(ExtendedTropical{max_k(property), Tropical{T}}; invert=false); usecuda=usecuda)
240242
elseif property isa SizeMin
241-
return post_invert_exponent.(contractx(gp, _x(ExtendedTropical{max_k(property), T}; invert=true); usecuda=usecuda))
243+
return post_invert_exponent.(contractx(gp, _x(ExtendedTropical{max_k(property), Tropical{T}}; invert=true); usecuda=usecuda))
242244
elseif property isa CountingAll
243245
return contractx(gp, one(T); usecuda=usecuda)
244246
elseif property isa CountingMax{1}
@@ -251,10 +253,14 @@ function solve(gp::GraphProblem, property::AbstractProperty; T=Float64, usecuda=
251253
return post_invert_exponent.(contractx(gp, pre_invert_exponent(TruncatedPoly(ntuple(i->i == min_k(property) ? one(T) : zero(T), min_k(property)), one(T))); usecuda=usecuda))
252254
elseif property isa GraphPolynomial
253255
return graph_polynomial(gp, Val(graph_polynomial_method(property)); usecuda=usecuda, T=T, property.kwargs...)
254-
elseif property isa SingleConfigMax{false}
255-
return solutions(gp, CountingTropical{T,T}; all=false, usecuda=usecuda, )
256-
elseif property isa SingleConfigMin{false}
256+
elseif property isa SingleConfigMax{1,false}
257+
return solutions(gp, CountingTropical{T,T}; all=false, usecuda=usecuda)
258+
elseif property isa (SingleConfigMax{K,false} where K)
259+
return solutions(gp, ExtendedTropical{max_k(property),CountingTropical{T,T}}; all=false, usecuda=usecuda)
260+
elseif property isa SingleConfigMin{1,false}
257261
return solutions(gp, CountingTropical{T,T}; all=false, usecuda=usecuda, invert=true)
262+
elseif property isa (SingleConfigMin{K,false} where K)
263+
return solutions(gp, ExtendedTropical{min_k(property),CountingTropical{T,T}}; all=false, usecuda=usecuda, invert=true)
258264
elseif property isa ConfigsMax{1,false}
259265
return solutions(gp, CountingTropical{T,T}; all=true, usecuda=usecuda, tree_storage=tree_storage(property))
260266
elseif property isa ConfigsMin{1,false}
@@ -265,10 +271,16 @@ function solve(gp::GraphProblem, property::AbstractProperty; T=Float64, usecuda=
265271
return solutions(gp, TruncatedPoly{min_k(property),T,T}; all=true, usecuda=usecuda, invert=true)
266272
elseif property isa ConfigsAll
267273
return solutions(gp, Real; all=true, usecuda=usecuda, tree_storage=tree_storage(property))
268-
elseif property isa SingleConfigMax{true}
274+
elseif property isa SingleConfigMax{1,true}
269275
return best_solutions(gp; all=false, usecuda=usecuda, T=T)
270-
elseif property isa SingleConfigMin{true}
276+
elseif property isa (SingleConfigMax{K,true} where K)
277+
@warn "bounded `SingleConfigMax` property for `K != 1` is not implemented. Switching to the unbounded version."
278+
return solve(gp, SingleConfigMax{max_k(property),false}(); T, usecuda)
279+
elseif property isa SingleConfigMin{1,true}
271280
return best_solutions(gp; all=false, usecuda=usecuda, invert=true, T=T)
281+
elseif property isa (SingleConfigMin{K,true} where K)
282+
@warn "bounded `SingleConfigMin` property for `K != 1` is not implemented. Switching to the unbounded version."
283+
return solve(gp, SingleConfigMin{min_k(property),false}(); T, usecuda)
272284
elseif property isa ConfigsMax{1,true}
273285
return best_solutions(gp; all=true, usecuda=usecuda, tree_storage=tree_storage(property), T=T)
274286
elseif property isa ConfigsMin{1,true}
@@ -396,10 +408,20 @@ Memory estimation in number of bytes to compute certain `property` of a `problem
396408
function estimate_memory(problem::GraphProblem, property::AbstractProperty; T=Float64)
397409
_estimate_memory(tensor_element_type(T, length(labels(problem)), nflavor(problem), property), problem)
398410
end
399-
function estimate_memory(problem::GraphProblem, ::Union{SingleConfigMax{true},SingleConfigMin{true}}; T=Float64)
411+
function estimate_memory(problem::GraphProblem, property::Union{SingleConfigMax{K,BOUNDED},SingleConfigMin{K,BOUNDED}}; T=Float64) where {K, BOUNDED}
400412
tc, sc, rw = timespacereadwrite_complexity(problem.code, _size_dict(problem))
401413
# caching all tensors is equivalent to counting the total number of writes
402-
return ceil(Int, exp2(rw - 1)) * sizeof(Tropical{T})
414+
if K == 1 && BOUNDED
415+
return ceil(Int, exp2(rw - 1)) * sizeof(Tropical{T})
416+
elseif K == 1 & !BOUNDED
417+
n, nf = length(labels(problem)), nflavor(problem)
418+
return peak_memory(problem.code, _size_dict(problem)) * (sizeof(tensor_element_type(T, n, nf, property)) * K)
419+
else
420+
# NOTE: the `K > 1` case does not respect bounding
421+
n, nf = length(labels(problem)), nflavor(problem)
422+
TT = tensor_element_type(T, n, nf, property)
423+
return peak_memory(problem.code, _size_dict(problem)) * (sizeof(tensor_element_type(T, n, nf, SingleConfigMax{1,BOUNDED}())) * K + sizeof(TT))
424+
end
403425
end
404426
function estimate_memory(problem::GraphProblem, ::GraphPolynomial{:polynomial}; T=Float64)
405427
# this is the upper bound
@@ -424,8 +446,7 @@ end
424446

425447
for (PROP, ET) in [
426448
(:(SizeMax{1}), :(Tropical{T})), (:(SizeMin{1}), :(Tropical{T})),
427-
(:(SizeMax{K}), :(ExtendedTropical{K,T})), (:(SizeMin{K}), :(ExtendedTropical{K,T})),
428-
(:(SingleConfigMax{true}), :(Tropical{T})), (:(SingleConfigMin{true}), :(Tropical{T})),
449+
(:(SizeMax{K}), :(ExtendedTropical{K,Tropical{T}})), (:(SizeMin{K}), :(ExtendedTropical{K,Tropical{T}})),
429450
(:(CountingAll), :T), (:(CountingMax{1}), :(CountingTropical{T,T})), (:(CountingMin{1}), :(CountingTropical{T,T})),
430451
(:(CountingMax{K}), :(TruncatedPoly{K,T,T})), (:(CountingMin{K}), :(TruncatedPoly{K,T,T})),
431452
(:(GraphPolynomial{:finitefield}), :(Mod{N,Int32} where N)), (:(GraphPolynomial{:fft}), :(Complex{T})),
@@ -434,11 +455,17 @@ for (PROP, ET) in [
434455
@eval tensor_element_type(::Type{T}, n::Int, nflavor::Int, ::$PROP) where {T,K} = $ET
435456
end
436457

437-
for (PROP, ET) in [(:(SingleConfigMax{false}), :(CountingTropical{T,T})), (:(SingleConfigMin{false}), :(CountingTropical{T,T}))]
438-
@eval function tensor_element_type(::Type{T}, n::Int, nflavor::Int, ::$PROP) where {T}
439-
sampler_type($ET, n, nflavor)
458+
function tensor_element_type(::Type{T}, n::Int, nflavor::Int, ::PROP) where {T, K, BOUNDED, PROP<:Union{SingleConfigMax{K,BOUNDED},SingleConfigMin{K,BOUNDED}}}
459+
if K == 1 && BOUNDED
460+
return Tropical{T}
461+
elseif K == 1 && !BOUNDED
462+
return sampler_type(CountingTropical{T,T}, n, nflavor)
463+
else
464+
# NOTE: the `K > 1` case does not respect bounding
465+
return sampler_type(ExtendedTropical{K,CountingTropical{T,T}}, n, nflavor)
440466
end
441467
end
468+
442469
for (PROP, ET) in [
443470
(:(ConfigsMax{1}), :(CountingTropical{T,T})), (:(ConfigsMin{1}), :(CountingTropical{T,T})),
444471
(:(ConfigsMax{K}), :(TruncatedPoly{K,T,T})), (:(ConfigsMin{K}), :(TruncatedPoly{K,T,T})),

src/networks/networks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ end
127127
add_labels!(tensors::AbstractVector{<:AbstractArray}, ixs, labels) = tensors
128128

129129
const SetPolyNumbers{T} = Union{Polynomial{T}, TruncatedPoly{K,T} where K, CountingTropical{TV,T} where TV} where T<:AbstractSetNumber
130-
function add_labels!(tensors::AbstractVector{<:AbstractArray{T}}, ixs, labels) where T <: Union{AbstractSetNumber, SetPolyNumbers}
130+
function add_labels!(tensors::AbstractVector{<:AbstractArray{T}}, ixs, labels) where T <: Union{AbstractSetNumber, SetPolyNumbers, ExtendedTropical{K,T} where {K,T<:SetPolyNumbers}}
131131
for (t, ix) in zip(tensors, ixs)
132132
for (dim, l) in enumerate(ix)
133133
index = findfirst(==(l), labels)

0 commit comments

Comments
 (0)