Skip to content

Commit 2d23774

Browse files
committed
rm permutedims patch
1 parent d0e1594 commit 2d23774

File tree

2 files changed

+2
-33
lines changed

2 files changed

+2
-33
lines changed

src/GraphTensorNetworks.jl

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,38 +12,6 @@ export GreedyMethod, TreeSA, SABipartite, KaHyParBipartite, MergeVectors, MergeG
1212

1313
project_relative_path(xs...) = normpath(joinpath(dirname(dirname(pathof(@__MODULE__))), xs...))
1414

15-
# patch to permutedims
16-
using Base.Cartesian
17-
using Base: size_to_strides, checkdims_perm
18-
for (V, PT, BT) in Any[((:N,), BitArray, BitArray), ((:T,:N), Array, StridedArray)]
19-
@eval @generated function Base.permutedims!(P::$PT{$(V...)}, B::$BT{$(V...)}, perm) where $(V...)
20-
quote
21-
checkdims_perm(P, B, perm)
22-
23-
#calculates all the strides
24-
native_strides = size_to_strides(1, size(B)...)
25-
strides_1 = 0
26-
@nexprs $N d->(strides_{d+1} = native_strides[perm[d]])
27-
28-
#Creates offset, because indexing starts at 1
29-
offset = 1 - sum(@ntuple $N d->strides_{d+1})
30-
31-
sumc = 0
32-
ind = 1
33-
@nexprs 1 d->(counts_{$N+1} = strides_{$N+1}) # a trick to set counts_($N+1)
34-
@nloops($N, i, P,
35-
d->(df_d=i_d*strides_{d+1} ;sumc += df_d), # PRE
36-
d->(sumc -= df_d), # POST
37-
begin # BODY
38-
@inbounds P[ind] = B[sumc+offset]
39-
ind += 1
40-
end)
41-
42-
return P
43-
end
44-
end
45-
end
46-
4715
include("bitvector.jl")
4816
include("arithematics.jl")
4917
include("networks.jl")

src/arithematics.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ export set_type, sampler_type
55
using Polynomials: Polynomial
66
using TropicalNumbers: Tropical, CountingTropical
77
using Mods, Primes
8+
using Base.Cartesian
89

910
# pirate
1011
Base.abs(x::Mod) = x
@@ -221,4 +222,4 @@ end
221222
onehotv(::Type{ConfigEnumerator{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigEnumerator([onehotv(StaticElementVector{N,S,C}, i, v)])
222223
onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampler(onehotv(StaticElementVector{N,S,C}, i, v))
223224
Base.transpose(c::ConfigEnumerator) = c
224-
Base.copy(c::ConfigEnumerator) = ConfigEnumerator(copy(c.data))
225+
Base.copy(c::ConfigEnumerator) = ConfigEnumerator(copy(c.data))

0 commit comments

Comments
 (0)