|
1 | 1 | export is_commutative_semiring
|
2 |
| -export Max2Poly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod, ConfigEnumerator, onehotv, ConfigSampler |
| 2 | +export Max2Poly, TruncatedPoly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod, ConfigEnumerator, onehotv, ConfigSampler |
3 | 3 | export set_type, sampler_type
|
4 | 4 |
|
5 | 5 | using Polynomials: Polynomial
|
@@ -60,48 +60,76 @@ function is_commutative_semiring(a::T, b::T, c::T) where T
|
60 | 60 | end
|
61 | 61 |
|
62 | 62 | # get maximum two countings (polynomial truncated to largest two orders)
|
63 |
| -struct Max2Poly{T,TO} <: Number |
64 |
| - a::T |
65 |
| - b::T |
| 63 | +struct TruncatedPoly{K,T,TO} <: Number |
| 64 | + coeffs::NTuple{K,T} |
66 | 65 | maxorder::TO
|
67 | 66 | end
|
| 67 | +const Max2Poly{T,TO} = TruncatedPoly{2,T,TO} |
| 68 | +Max2Poly(a, b, maxorder) = TruncatedPoly((a, b), maxorder) |
68 | 69 |
|
69 | 70 | function Base.:+(a::Max2Poly, b::Max2Poly)
|
| 71 | + aa, ab = a.coeffs |
| 72 | + ba, bb = b.coeffs |
70 | 73 | if a.maxorder == b.maxorder
|
71 |
| - return Max2Poly(a.a+b.a, a.b+b.b, a.maxorder) |
| 74 | + return Max2Poly(aa+ba, ab+bb, a.maxorder) |
72 | 75 | elseif a.maxorder == b.maxorder-1
|
73 |
| - return Max2Poly(a.b+b.a, b.b, b.maxorder) |
| 76 | + return Max2Poly(ab+ba, bb, b.maxorder) |
74 | 77 | elseif a.maxorder == b.maxorder+1
|
75 |
| - return Max2Poly(a.a+b.b, a.b, a.maxorder) |
| 78 | + return Max2Poly(aa+bb, ab, a.maxorder) |
76 | 79 | elseif a.maxorder < b.maxorder
|
77 | 80 | return b
|
78 | 81 | else
|
79 | 82 | return a
|
80 | 83 | end
|
81 | 84 | end
|
82 | 85 |
|
| 86 | +function Base.:+(a::TruncatedPoly{K}, b::TruncatedPoly{K}) where K |
| 87 | + if a.maxorder == b.maxorder |
| 88 | + return TruncatedPoly(a.coeffs .+ b.coeffs, a.maxorder) |
| 89 | + elseif a.maxorder > b.maxorder |
| 90 | + offset = a.maxorder - b.maxorder |
| 91 | + return TruncatedPoly(ntuple(i->i+offset <= K ? a.coeffs[i] + b.coeffs[i+offset] : a.coeffs[i], K), a.maxorder) |
| 92 | + else |
| 93 | + offset = b.maxorder - a.maxorder |
| 94 | + return TruncatedPoly(ntuple(i->i+offset <= K ? b.coeffs[i] + a.coeffs[i+offset] : b.coeffs[i], K), b.maxorder) |
| 95 | + end |
| 96 | +end |
| 97 | + |
83 | 98 | function Base.:*(a::Max2Poly, b::Max2Poly)
|
84 | 99 | maxorder = a.maxorder + b.maxorder
|
85 |
| - Max2Poly(a.a*b.b + a.b*b.a, a.b * b.b, maxorder) |
| 100 | + aa, ab = a.coeffs |
| 101 | + ba, bb = b.coeffs |
| 102 | + Max2Poly(aa*bb + ab*ba, ab * bb, maxorder) |
86 | 103 | end
|
87 | 104 |
|
88 |
| -Base.zero(::Type{Max2Poly{T,TO}}) where {T,TO} = Max2Poly(zero(T), zero(T), zero(Tropical{TO}).n) |
89 |
| -Base.one(::Type{Max2Poly{T,TO}}) where {T,TO} = Max2Poly(zero(T), one(T), zero(TO)) |
90 |
| -Base.zero(::Max2Poly{T,TO}) where {T,TO} = zero(Max2Poly{T,TO}) |
91 |
| -Base.one(::Max2Poly{T,TO}) where {T,TO} = one(Max2Poly{T,TO}) |
92 |
| - |
93 |
| -Base.show(io::IO, x::Max2Poly) = show(io, MIME"text/plain"(), x) |
94 |
| -function Base.show(io::IO, ::MIME"text/plain", x::Max2Poly) |
| 105 | +function Base.:*(a::TruncatedPoly{K,T}, b::TruncatedPoly{K,T}) where {K,T} |
| 106 | + maxorder = a.maxorder + b.maxorder |
| 107 | + TruncatedPoly(ntuple(K) do k |
| 108 | + r = zero(T) |
| 109 | + for i=1:K-k+1 |
| 110 | + r += a.coeffs[i+k-1]*b.coeffs[K-i+1] |
| 111 | + end |
| 112 | + return r |
| 113 | + end, maxorder) |
| 114 | +end |
| 115 | + |
| 116 | +Base.zero(::Type{TruncatedPoly{K,T,TO}}) where {K,T,TO} = TruncatedPoly(ntuple(i->zero(T), K), zero(Tropical{TO}).n) |
| 117 | +Base.one(::Type{TruncatedPoly{K,T,TO}}) where {K,T,TO} = TruncatedPoly(ntuple(i->i==K ? one(T) : zero(T), K), zero(TO)) |
| 118 | +Base.zero(::TruncatedPoly{K,T,TO}) where {K,T,TO} = zero(TruncatedPoly{T,TO}) |
| 119 | +Base.one(::TruncatedPoly{K,T,TO}) where {K,T,TO} = one(TruncatedPoly{T,TO}) |
| 120 | + |
| 121 | +Base.show(io::IO, x::TruncatedPoly) = show(io, MIME"text/plain"(), x) |
| 122 | +function Base.show(io::IO, ::MIME"text/plain", x::TruncatedPoly{K}) where K |
95 | 123 | if isinf(x.maxorder)
|
96 | 124 | print(io, 0)
|
97 | 125 | else
|
98 |
| - printpoly(io, Polynomial([x.a, x.b], :x), offset=Int(x.maxorder-1)) |
| 126 | + printpoly(io, Polynomial([x.coeffs...], :x), offset=Int(x.maxorder-K+1)) |
99 | 127 | end
|
100 | 128 | end
|
101 | 129 |
|
102 | 130 | # patch for CUDA matmul
|
103 |
| -Base.:*(a::Bool, y::Max2Poly{T,TO}) where {T,TO} = a ? y : zero(y) |
104 |
| -Base.:*(y::Max2Poly{T,TO}, a::Bool) where {T,TO} = a ? y : zero(y) |
| 131 | +Base.:*(a::Bool, y::TruncatedPoly{K,T,TO}) where {K,T,TO} = a ? y : zero(y) |
| 132 | +Base.:*(y::TruncatedPoly{K,T,TO}, a::Bool) where {K,T,TO} = a ? y : zero(y) |
105 | 133 |
|
106 | 134 | struct ConfigEnumerator{N,S,C}
|
107 | 135 | data::Vector{StaticElementVector{N,S,C}}
|
|
0 commit comments