1
1
using TupleTools
2
+ using OMEinsum: DynamicEinCode
2
3
3
4
"""
4
5
backward_tropical(mode, ixs, xs, iy, y, ymask, size_dict)
@@ -10,12 +11,12 @@ The backward rule for tropical einsum.
10
11
* `ymask` is the boolean mask for gradients,
11
12
* `size_dict` is a key-value map from tensor label to dimension size.
12
13
"""
13
- function backward_tropical (mode, @nospecialize ( ixs) , @nospecialize (xs), @nospecialize (iy) , @nospecialize (y), @nospecialize (ymask), size_dict)
14
+ function backward_tropical (mode, ixs, @nospecialize (xs:: Tuple ), iy , @nospecialize (y), @nospecialize (ymask), size_dict)
14
15
y .= inv .(y) .* ymask
15
16
masks = []
16
17
for i= 1 : length (ixs)
17
- nixs = TupleTools . insertat (ixs, i, (iy,) )
18
- nxs = TupleTools . insertat ( xs, i, (y,) )
18
+ nixs = OMEinsum . _insertat (ixs, i, iy )
19
+ nxs = OMEinsum . _insertat ( xs, i, y )
19
20
niy = ixs[i]
20
21
if mode == :all
21
22
mask = zeros (Bool, size (xs[i]))
@@ -53,34 +54,39 @@ struct CacheTree{T}
53
54
content:: AbstractArray{T}
54
55
siblings:: Vector{CacheTree{T}}
55
56
end
56
- function cached_einsum (code:: Int , @nospecialize (xs), size_dict)
57
- y = xs[code]
58
- CacheTree (y, CacheTree{eltype (y)}[])
59
- end
60
57
function cached_einsum (code:: NestedEinsum , @nospecialize (xs), size_dict)
61
- caches = [cached_einsum (arg, xs, size_dict) for arg in code. args]
62
- y = code. eins (getfield .(caches, :content )... ; size_info= size_dict)
63
- CacheTree (y, caches)
58
+ if OMEinsum. isleaf (code)
59
+ y = xs[code. tensorindex]
60
+ return CacheTree (y, CacheTree{eltype (y)}[])
61
+ else
62
+ caches = [cached_einsum (arg, xs, size_dict) for arg in code. args]
63
+ y = einsum (code. eins, ntuple (i-> caches[i]. content, length (caches)), size_dict)
64
+ return CacheTree (y, caches)
65
+ end
64
66
end
65
67
66
68
# computed mask tree by back propagation
67
- function generate_masktree (code:: Int , cache, mask, size_dict, mode= :all )
68
- CacheTree (mask, CacheTree{Bool}[])
69
- end
70
69
function generate_masktree (code:: NestedEinsum , cache, mask, size_dict, mode= :all )
71
- submasks = backward_tropical (mode, getixs (code. eins), (getfield .(cache. siblings, :content )... ,), OMEinsum. getiy (code. eins), cache. content, mask, size_dict)
72
- return CacheTree (mask, generate_masktree .(code. args, cache. siblings, submasks, Ref (size_dict), mode))
70
+ if OMEinsum. isleaf (code)
71
+ return CacheTree (mask, CacheTree{Bool}[])
72
+ else
73
+ submasks = backward_tropical (mode, getixs (code. eins), (getfield .(cache. siblings, :content )... ,), OMEinsum. getiy (code. eins), cache. content, mask, size_dict)
74
+ return CacheTree (mask, generate_masktree .(code. args, cache. siblings, submasks, Ref (size_dict), mode))
75
+ end
73
76
end
74
77
75
78
# The masked einsum contraction
76
- function masked_einsum (code:: Int , @nospecialize (xs), masks, size_dict)
77
- y = copy (xs[code])
78
- y[OMEinsum. asarray (.! masks. content)] .= Ref (zero (eltype (y))); y
79
- end
80
79
function masked_einsum (code:: NestedEinsum , @nospecialize (xs), masks, size_dict)
81
- xs = [masked_einsum (arg, xs, mask, size_dict) for (arg, mask) in zip (code. args, masks. siblings)]
82
- y = einsum (code. eins, (xs... ,), size_dict)
83
- y[OMEinsum. asarray (.! masks. content)] .= Ref (zero (eltype (y))); y
80
+ if OMEinsum. isleaf (code)
81
+ y = copy (xs[code. tensorindex])
82
+ y[OMEinsum. asarray (.! masks. content)] .= Ref (zero (eltype (y)))
83
+ return y
84
+ else
85
+ xs = [masked_einsum (arg, xs, mask, size_dict) for (arg, mask) in zip (code. args, masks. siblings)]
86
+ y = einsum (code. eins, (xs... ,), size_dict)
87
+ y[OMEinsum. asarray (.! masks. content)] .= Ref (zero (eltype (y)))
88
+ return y
89
+ end
84
90
end
85
91
86
92
"""
@@ -92,8 +98,9 @@ Contraction method with bounding.
92
98
* `xsb` are input tensors for computing, e.g. tensors elements are counting tropical with set algebra,
93
99
* `ymask` is the initial gradient mask for the output tensor.
94
100
"""
95
- function bounding_contract (@nospecialize (code:: EinCode ), @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
96
- bounding_contract (NestedEinsum ((1 : length (xsa)), code), xsa, ymask, xsb; size_info= size_info)
101
+ function bounding_contract (code:: EinCode , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
102
+ LT = OMEinsum. labeltype (code)
103
+ bounding_contract (NestedEinsum (NestedEinsum {DynamicEinCode{LT}} .(1 : length (xsa)), code), xsa, ymask, xsb; size_info= size_info)
97
104
end
98
105
function bounding_contract (code:: NestedEinsum , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
99
106
size_dict = size_info=== nothing ? Dict {OMEinsum.labeltype(code.eins),Int} () : copy (size_info)
@@ -109,8 +116,9 @@ function bounding_contract(code::NestedEinsum, @nospecialize(xsa), ymask, @nospe
109
116
end
110
117
111
118
# get the optimal solution with automatic differentiation.
112
- function solution_ad (@nospecialize (code:: EinCode ), @nospecialize (xsa), ymask; size_info= nothing )
113
- solution_ad (NestedEinsum ((1 : length (xsa)), code), xsa, ymask; size_info= size_info)
119
+ function solution_ad (code:: EinCode , @nospecialize (xsa), ymask; size_info= nothing )
120
+ LT = OMEinsum. labeltype (code)
121
+ solution_ad (NestedEinsum (NestedEinsum {DynamicEinCode{LT}} .(1 : length (xsa)), code), xsa, ymask; size_info= size_info)
114
122
end
115
123
116
124
function solution_ad (code:: NestedEinsum , @nospecialize (xsa), ymask; size_info= nothing )
128
136
129
137
function read_config! (code:: NestedEinsum , mt, out)
130
138
for (arg, ix, sibling) in zip (code. args, getixs (code. eins), mt. siblings)
131
- if arg isa Int
139
+ if OMEinsum . isleaf ( arg)
132
140
assign = convert (Array, sibling. content) # note: the content can be CuArray
133
141
if length (ix) == 1
134
142
if ! assign[1 ] && assign[2 ]
0 commit comments