Skip to content

Commit 5c64ff6

Browse files
committed
update
1 parent d4ef1ac commit 5c64ff6

File tree

4 files changed

+37
-36
lines changed

4 files changed

+37
-36
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Compose = "0.9"
2828
FFTW = "1.4"
2929
LightGraphs = "1.3"
3030
Mods = "1.3"
31-
OMEinsum = "0.4, 0.5"
31+
OMEinsum = "0.6.1"
3232
OMEinsumContractionOrders = "0.5"
3333
Polynomials = "2.0"
3434
Primes = "0.5"

src/GraphTensorNetworks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using OMEinsumContractionOrders: OMEinsum
44
using Core: Argument
55
using TropicalGEMM, TropicalNumbers
66
using OMEinsum
7-
using OMEinsum: timespace_complexity
7+
using OMEinsum: timespace_complexity, collect_ixs
88
using LightGraphs
99

1010
export timespace_complexity, @ein_str

src/bounding.jl

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using TupleTools
2+
using OMEinsum: DynamicEinCode
23

34
"""
45
backward_tropical(mode, ixs, xs, iy, y, ymask, size_dict)
@@ -10,12 +11,12 @@ The backward rule for tropical einsum.
1011
* `ymask` is the boolean mask for gradients,
1112
* `size_dict` is a key-value map from tensor label to dimension size.
1213
"""
13-
function backward_tropical(mode, @nospecialize(ixs), @nospecialize(xs), @nospecialize(iy), @nospecialize(y), @nospecialize(ymask), size_dict)
14+
function backward_tropical(mode, ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), @nospecialize(ymask), size_dict)
1415
y .= inv.(y) .* ymask
1516
masks = []
1617
for i=1:length(ixs)
17-
nixs = TupleTools.insertat(ixs, i, (iy,))
18-
nxs = TupleTools.insertat( xs, i, (y,))
18+
nixs = OMEinsum._insertat(ixs, i, iy)
19+
nxs = OMEinsum._insertat( xs, i, y)
1920
niy = ixs[i]
2021
if mode == :all
2122
mask = zeros(Bool, size(xs[i]))
@@ -53,34 +54,39 @@ struct CacheTree{T}
5354
content::AbstractArray{T}
5455
siblings::Vector{CacheTree{T}}
5556
end
56-
function cached_einsum(code::Int, @nospecialize(xs), size_dict)
57-
y = xs[code]
58-
CacheTree(y, CacheTree{eltype(y)}[])
59-
end
6057
function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
61-
caches = [cached_einsum(arg, xs, size_dict) for arg in code.args]
62-
y = code.eins(getfield.(caches, :content)...; size_info=size_dict)
63-
CacheTree(y, caches)
58+
if OMEinsum.isleaf(code)
59+
y = xs[code.tensorindex]
60+
return CacheTree(y, CacheTree{eltype(y)}[])
61+
else
62+
caches = [cached_einsum(arg, xs, size_dict) for arg in code.args]
63+
y = einsum(code.eins, ntuple(i->caches[i].content, length(caches)), size_dict)
64+
return CacheTree(y, caches)
65+
end
6466
end
6567

6668
# computed mask tree by back propagation
67-
function generate_masktree(code::Int, cache, mask, size_dict, mode=:all)
68-
CacheTree(mask, CacheTree{Bool}[])
69-
end
7069
function generate_masktree(code::NestedEinsum, cache, mask, size_dict, mode=:all)
71-
submasks = backward_tropical(mode, getixs(code.eins), (getfield.(cache.siblings, :content)...,), OMEinsum.getiy(code.eins), cache.content, mask, size_dict)
72-
return CacheTree(mask, generate_masktree.(code.args, cache.siblings, submasks, Ref(size_dict), mode))
70+
if OMEinsum.isleaf(code)
71+
return CacheTree(mask, CacheTree{Bool}[])
72+
else
73+
submasks = backward_tropical(mode, getixs(code.eins), (getfield.(cache.siblings, :content)...,), OMEinsum.getiy(code.eins), cache.content, mask, size_dict)
74+
return CacheTree(mask, generate_masktree.(code.args, cache.siblings, submasks, Ref(size_dict), mode))
75+
end
7376
end
7477

7578
# The masked einsum contraction
76-
function masked_einsum(code::Int, @nospecialize(xs), masks, size_dict)
77-
y = copy(xs[code])
78-
y[OMEinsum.asarray(.!masks.content)] .= Ref(zero(eltype(y))); y
79-
end
8079
function masked_einsum(code::NestedEinsum, @nospecialize(xs), masks, size_dict)
81-
xs = [masked_einsum(arg, xs, mask, size_dict) for (arg, mask) in zip(code.args, masks.siblings)]
82-
y = einsum(code.eins, (xs...,), size_dict)
83-
y[OMEinsum.asarray(.!masks.content)] .= Ref(zero(eltype(y))); y
80+
if OMEinsum.isleaf(code)
81+
y = copy(xs[code.tensorindex])
82+
y[OMEinsum.asarray(.!masks.content)] .= Ref(zero(eltype(y)))
83+
return y
84+
else
85+
xs = [masked_einsum(arg, xs, mask, size_dict) for (arg, mask) in zip(code.args, masks.siblings)]
86+
y = einsum(code.eins, (xs...,), size_dict)
87+
y[OMEinsum.asarray(.!masks.content)] .= Ref(zero(eltype(y)))
88+
return y
89+
end
8490
end
8591

8692
"""
@@ -92,8 +98,9 @@ Contraction method with bounding.
9298
* `xsb` are input tensors for computing, e.g. tensors elements are counting tropical with set algebra,
9399
* `ymask` is the initial gradient mask for the output tensor.
94100
"""
95-
function bounding_contract(@nospecialize(code::EinCode), @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
96-
bounding_contract(NestedEinsum((1:length(xsa)), code), xsa, ymask, xsb; size_info=size_info)
101+
function bounding_contract(code::EinCode, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
102+
LT = OMEinsum.labeltype(code)
103+
bounding_contract(NestedEinsum(NestedEinsum{DynamicEinCode{LT}}.(1:length(xsa)), code), xsa, ymask, xsb; size_info=size_info)
97104
end
98105
function bounding_contract(code::NestedEinsum, @nospecialize(xsa), ymask, @nospecialize(xsb); size_info=nothing)
99106
size_dict = size_info===nothing ? Dict{OMEinsum.labeltype(code.eins),Int}() : copy(size_info)
@@ -109,8 +116,9 @@ function bounding_contract(code::NestedEinsum, @nospecialize(xsa), ymask, @nospe
109116
end
110117

111118
# get the optimal solution with automatic differentiation.
112-
function solution_ad(@nospecialize(code::EinCode), @nospecialize(xsa), ymask; size_info=nothing)
113-
solution_ad(NestedEinsum((1:length(xsa)), code), xsa, ymask; size_info=size_info)
119+
function solution_ad(code::EinCode, @nospecialize(xsa), ymask; size_info=nothing)
120+
LT = OMEinsum.labeltype(code)
121+
solution_ad(NestedEinsum(NestedEinsum{DynamicEinCode{LT}}.(1:length(xsa)), code), xsa, ymask; size_info=size_info)
114122
end
115123

116124
function solution_ad(code::NestedEinsum, @nospecialize(xsa), ymask; size_info=nothing)
@@ -128,7 +136,7 @@ end
128136

129137
function read_config!(code::NestedEinsum, mt, out)
130138
for (arg, ix, sibling) in zip(code.args, getixs(code.eins), mt.siblings)
131-
if arg isa Int
139+
if OMEinsum.isleaf(arg)
132140
assign = convert(Array, sibling.content) # note: the content can be CuArray
133141
if length(ix) == 1
134142
if !assign[1] && assign[2]

src/networks.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,6 @@ function labels(code::EinTypes)
100100
return res
101101
end
102102

103-
collect_ixs(ne::EinCode) = [collect(ix) for ix in getixs(ne)]
104-
function collect_ixs(ne::NestedEinsum)
105-
d = OMEinsum.collect_ixs!(ne, Dict{Int,Vector{OMEinsum.labeltype(ne.eins)}}())
106-
ks = sort!(collect(keys(d)))
107-
return @inbounds [d[i] for i in ks]
108-
end
109-
110103
"""
111104
optimize_code(code; optmethod=:kahypar, sc_target=17, max_group_size=40, nrepeat=10, imbalances=0.0:0.001:0.8, βs=0.01:0.05:10.0, ntrials=50, niters=1000, sc_weight=2.0, rw_weight=1.0)
112105

0 commit comments

Comments (0)