Skip to content

Commit 3e2949e

Browse files
committed
update truncated poly
1 parent fe48be8 commit 3e2949e

File tree

6 files changed

+92
-37
lines changed

6 files changed

+92
-37
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1919
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2020
TropicalGEMM = "a4ad3063-64a7-4bad-8738-34ed09bc0236"
2121
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
22-
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
2322
Viznet = "52a3aca4-6234-47fd-b74a-806bdf78ede9"
2423

2524
[compat]
@@ -35,7 +34,6 @@ Primes = "0.5"
3534
Requires = "1"
3635
TropicalGEMM = "0.1"
3736
TropicalNumbers = "0.4, 0.5"
38-
TupleTools = "1.2"
3937
Viznet = "0.3"
4038
julia = "1"
4139

src/arithematics.jl

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
export is_commutative_semiring
2-
export Max2Poly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod, ConfigEnumerator, onehotv, ConfigSampler
2+
export Max2Poly, TruncatedPoly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod, ConfigEnumerator, onehotv, ConfigSampler
33
export set_type, sampler_type
44

55
using Polynomials: Polynomial
@@ -60,48 +60,76 @@ function is_commutative_semiring(a::T, b::T, c::T) where T
6060
end
6161

6262
# get maximum two countings (polynomial truncated to largest two orders)
63-
struct Max2Poly{T,TO} <: Number
64-
a::T
65-
b::T
63+
struct TruncatedPoly{K,T,TO} <: Number
64+
coeffs::NTuple{K,T}
6665
maxorder::TO
6766
end
67+
const Max2Poly{T,TO} = TruncatedPoly{2,T,TO}
68+
Max2Poly(a, b, maxorder) = TruncatedPoly((a, b), maxorder)
6869

6970
function Base.:+(a::Max2Poly, b::Max2Poly)
71+
aa, ab = a.coeffs
72+
ba, bb = b.coeffs
7073
if a.maxorder == b.maxorder
71-
return Max2Poly(a.a+b.a, a.b+b.b, a.maxorder)
74+
return Max2Poly(aa+ba, ab+bb, a.maxorder)
7275
elseif a.maxorder == b.maxorder-1
73-
return Max2Poly(a.b+b.a, b.b, b.maxorder)
76+
return Max2Poly(ab+ba, bb, b.maxorder)
7477
elseif a.maxorder == b.maxorder+1
75-
return Max2Poly(a.a+b.b, a.b, a.maxorder)
78+
return Max2Poly(aa+bb, ab, a.maxorder)
7679
elseif a.maxorder < b.maxorder
7780
return b
7881
else
7982
return a
8083
end
8184
end
8285

86+
function Base.:+(a::TruncatedPoly{K}, b::TruncatedPoly{K}) where K
87+
if a.maxorder == b.maxorder
88+
return TruncatedPoly(a.coeffs .+ b.coeffs, a.maxorder)
89+
elseif a.maxorder > b.maxorder
90+
offset = a.maxorder - b.maxorder
91+
return TruncatedPoly(ntuple(i->i+offset <= K ? a.coeffs[i] + b.coeffs[i+offset] : a.coeffs[i], K), a.maxorder)
92+
else
93+
offset = b.maxorder - a.maxorder
94+
return TruncatedPoly(ntuple(i->i+offset <= K ? b.coeffs[i] + a.coeffs[i+offset] : b.coeffs[i], K), b.maxorder)
95+
end
96+
end
97+
8398
function Base.:*(a::Max2Poly, b::Max2Poly)
8499
maxorder = a.maxorder + b.maxorder
85-
Max2Poly(a.a*b.b + a.b*b.a, a.b * b.b, maxorder)
100+
aa, ab = a.coeffs
101+
ba, bb = b.coeffs
102+
Max2Poly(aa*bb + ab*ba, ab * bb, maxorder)
86103
end
87104

88-
Base.zero(::Type{Max2Poly{T,TO}}) where {T,TO} = Max2Poly(zero(T), zero(T), zero(Tropical{TO}).n)
89-
Base.one(::Type{Max2Poly{T,TO}}) where {T,TO} = Max2Poly(zero(T), one(T), zero(TO))
90-
Base.zero(::Max2Poly{T,TO}) where {T,TO} = zero(Max2Poly{T,TO})
91-
Base.one(::Max2Poly{T,TO}) where {T,TO} = one(Max2Poly{T,TO})
92-
93-
Base.show(io::IO, x::Max2Poly) = show(io, MIME"text/plain"(), x)
94-
function Base.show(io::IO, ::MIME"text/plain", x::Max2Poly)
105+
function Base.:*(a::TruncatedPoly{K,T}, b::TruncatedPoly{K,T}) where {K,T}
106+
maxorder = a.maxorder + b.maxorder
107+
TruncatedPoly(ntuple(K) do k
108+
r = zero(T)
109+
for i=1:K-k+1
110+
r += a.coeffs[i+k-1]*b.coeffs[K-i+1]
111+
end
112+
return r
113+
end, maxorder)
114+
end
115+
116+
Base.zero(::Type{TruncatedPoly{K,T,TO}}) where {K,T,TO} = TruncatedPoly(ntuple(i->zero(T), K), zero(Tropical{TO}).n)
117+
Base.one(::Type{TruncatedPoly{K,T,TO}}) where {K,T,TO} = TruncatedPoly(ntuple(i->i==K ? one(T) : zero(T), K), zero(TO))
118+
Base.zero(::TruncatedPoly{K,T,TO}) where {K,T,TO} = zero(TruncatedPoly{T,TO})
119+
Base.one(::TruncatedPoly{K,T,TO}) where {K,T,TO} = one(TruncatedPoly{T,TO})
120+
121+
Base.show(io::IO, x::TruncatedPoly) = show(io, MIME"text/plain"(), x)
122+
function Base.show(io::IO, ::MIME"text/plain", x::TruncatedPoly{K}) where K
95123
if isinf(x.maxorder)
96124
print(io, 0)
97125
else
98-
printpoly(io, Polynomial([x.a, x.b], :x), offset=Int(x.maxorder-1))
126+
printpoly(io, Polynomial([x.coeffs...], :x), offset=Int(x.maxorder-K+1))
99127
end
100128
end
101129

102130
# patch for CUDA matmul
103-
Base.:*(a::Bool, y::Max2Poly{T,TO}) where {T,TO} = a ? y : zero(y)
104-
Base.:*(y::Max2Poly{T,TO}, a::Bool) where {T,TO} = a ? y : zero(y)
131+
Base.:*(a::Bool, y::TruncatedPoly{K,T,TO}) where {K,T,TO} = a ? y : zero(y)
132+
Base.:*(y::TruncatedPoly{K,T,TO}, a::Bool) where {K,T,TO} = a ? y : zero(y)
105133

106134
struct ConfigEnumerator{N,S,C}
107135
data::Vector{StaticElementVector{N,S,C}}

src/bounding.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
using TupleTools
21
using OMEinsum: DynamicEinCode
32

43
"""

test/arithematics.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using GraphTensorNetworks, Test, OMEinsum, OMEinsumContractionOrders
2+
using Mods, Polynomials, TropicalNumbers
3+
using LightGraphs, Random
4+
using GraphTensorNetworks: StaticBitVector
5+
6+
@testset "truncated poly" begin
7+
p1 = TruncatedPoly((2,2,1), 2.0)
8+
p2 = TruncatedPoly((2,3,9), 4.0)
9+
x = Polynomial([2, 2, 1])
10+
y = Polynomial([0, 0, 2, 3, 9])
11+
r1 = p1 + p2
12+
r2 = p2 + p1
13+
r3 = x + y
14+
@test r1.coeffs == r2.coeffs == (r3.coeffs[end-2:end]...,)
15+
q1 = p1 * p2
16+
q2 = p2 * p1
17+
q3 = x * y
18+
@test q1.coeffs == q2.coeffs == (q3.coeffs[end-2:end]...,)
19+
r1 = p1 + p1
20+
r3 = x + x
21+
@test r1.coeffs == (r3.coeffs[end-2:end]...,)
22+
r1 = p1 * p1
23+
r3 = x * x
24+
@test r1.coeffs == (r3.coeffs[end-2:end]...,)
25+
end
26+
27+
@testset "arithematics" begin
28+
for (a, b, c) in [
29+
(TropicalF64(2), TropicalF64(8), TropicalF64(9)),
30+
(CountingTropicalF64(2, 8), CountingTropicalF64(8, 9), CountingTropicalF64(9, 2)),
31+
(Mod{17}(2), Mod{17}(8), Mod{17}(9)),
32+
(Polynomial([0,1,2,3.0]), Polynomial([3,2.0]), Polynomial([1,7.0])),
33+
(Max2Poly(1,2,3.0), Max2Poly(3,2,2.0), Max2Poly(4,7,1.0)),
34+
(TruncatedPoly((1,2,3),3.0), TruncatedPoly((7,3,2),2.0), TruncatedPoly((1,4,7),1.0)),
35+
(TropicalF64(5), TropicalF64(3), TropicalF64(-9)),
36+
(CountingTropicalF64(5, 3), CountingTropicalF64(3, 9), CountingTropicalF64(-3, 2)),
37+
(CountingTropical(5.0, ConfigSampler(StaticBitVector(rand(Bool, 10)))), CountingTropical(3.0, ConfigSampler(StaticBitVector(rand(Bool, 10)))), CountingTropical(-3.0, ConfigSampler(StaticBitVector(rand(Bool, 10))))),
38+
(CountingTropical(5.0, ConfigEnumerator([StaticBitVector(rand(Bool, 10)) for j=1:3])), CountingTropical(3.0, ConfigEnumerator([StaticBitVector(rand(Bool, 10)) for j=1:4])), CountingTropical(-3.0, ConfigEnumerator([StaticBitVector(rand(Bool, 10)) for j=1:5]))),
39+
]
40+
@test is_commutative_semiring(a, b, c)
41+
end
42+
end

test/graph_polynomials.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,6 @@ end
2828
@test p1 p4
2929
end
3030

31-
@testset "arithematics" begin
32-
for (a, b, c) in [
33-
(TropicalF64(2), TropicalF64(8), TropicalF64(9)),
34-
(CountingTropicalF64(2, 8), CountingTropicalF64(8, 9), CountingTropicalF64(9, 2)),
35-
(Mod{17}(2), Mod{17}(8), Mod{17}(9)),
36-
(Polynomial([0,1,2,3.0]), Polynomial([3,2.0]), Polynomial([1,7.0])),
37-
(Max2Poly(1,2,3.0), Max2Poly(3,2,2.0), Max2Poly(4,7,1.0)),
38-
(TropicalF64(5), TropicalF64(3), TropicalF64(-9)),
39-
(CountingTropicalF64(5, 3), CountingTropicalF64(3, 9), CountingTropicalF64(-3, 2)),
40-
(CountingTropical(5.0, ConfigSampler(StaticBitVector(rand(Bool, 10)))), CountingTropical(3.0, ConfigSampler(StaticBitVector(rand(Bool, 10)))), CountingTropical(-3.0, ConfigSampler(StaticBitVector(rand(Bool, 10))))),
41-
(CountingTropical(5.0, ConfigEnumerator([StaticBitVector(rand(Bool, 10)) for j=1:3])), CountingTropical(3.0, ConfigEnumerator([StaticBitVector(rand(Bool, 10)) for j=1:4])), CountingTropical(-3.0, ConfigEnumerator([StaticBitVector(rand(Bool, 10)) for j=1:5]))),
42-
]
43-
@test is_commutative_semiring(a, b, c)
44-
end
45-
end
46-
4731
@testset "counting maximal IS" begin
4832
g = random_regular_graph(20, 3)
4933
gp = MaximalIndependence(g, optimizer=KaHyParBipartite(sc_target=20))

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ using Test
55
include("bitvector.jl")
66
end
77

8+
@testset "arithematics" begin
9+
include("arithematics.jl")
10+
end
11+
812
@testset "independence polynomial" begin
913
include("graph_polynomials.jl")
1014
end

0 commit comments

Comments
 (0)