1
1
using OMEinsum: DynamicEinCode
2
2
3
+ struct AllConfigs{K} end
4
+ largest_k (:: AllConfigs{K} ) where K = K
5
+ struct SingleConfig end
6
+
3
7
"""
4
8
backward_tropical(mode, ixs, xs, iy, y, ymask, size_dict)
5
9
@@ -17,11 +21,11 @@ function backward_tropical(mode, ixs, @nospecialize(xs::Tuple), iy, @nospecializ
17
21
nixs = OMEinsum. _insertat (ixs, i, iy)
18
22
nxs = OMEinsum. _insertat ( xs, i, y)
19
23
niy = ixs[i]
20
- if mode == :all
24
+ if mode isa AllConfigs
21
25
mask = zeros (Bool, size (xs[i]))
22
- mask .= inv .(einsum (EinCode (nixs, niy), nxs, size_dict)) .== xs[i]
26
+ mask .= inv .(einsum (EinCode (nixs, niy), nxs, size_dict)) .<= xs[i] .* Tropical ( largest_k (mode) - 1 )
23
27
push! (masks, mask)
24
- elseif mode == :single # wrong, need `B` matching `A`.
28
+ elseif mode isa SingleConfig
25
29
A = zeros (eltype (xs[i]), size (xs[i]))
26
30
A = einsum (EinCode (nixs, niy), nxs, size_dict)
27
31
push! (masks, onehotmask (A, xs[i]))
@@ -65,12 +69,12 @@ function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
65
69
end
66
70
67
71
# computed mask tree by back propagation
68
- function generate_masktree (code:: NestedEinsum , cache, mask, size_dict, mode = :all )
72
+ function generate_masktree (mode, code:: NestedEinsum , cache, mask, size_dict)
69
73
if OMEinsum. isleaf (code)
70
74
return CacheTree (mask, CacheTree{Bool}[])
71
75
else
72
76
submasks = backward_tropical (mode, getixs (code. eins), (getfield .(cache. siblings, :content )... ,), OMEinsum. getiy (code. eins), cache. content, mask, size_dict)
73
- return CacheTree (mask, generate_masktree .(code. args, cache. siblings, submasks, Ref (size_dict), mode ))
77
+ return CacheTree (mask, generate_masktree .(Ref (mode), code. args, cache. siblings, submasks, Ref (size_dict)))
74
78
end
75
79
end
76
80
@@ -89,27 +93,28 @@ function masked_einsum(code::NestedEinsum, @nospecialize(xs), masks, size_dict)
89
93
end
90
94
91
95
"""
92
- bounding_contract(code, xsa, ymask, xsb; size_info=nothing)
96
+ bounding_contract(mode, code, xsa, ymask, xsb; size_info=nothing)
93
97
94
98
Contraction method with bounding.
95
99
100
+ * `mode` is a `AllConfigs{K}` instance, where `MIS-K+1` is the largest IS size that you care about.
96
101
* `xsa` are input tensors for bounding, e.g. tropical tensors,
97
102
* `xsb` are input tensors for computing, e.g. tensors elements are counting tropical with set algebra,
98
103
* `ymask` is the initial gradient mask for the output tensor.
99
104
"""
100
- function bounding_contract (code:: EinCode , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
105
+ function bounding_contract (mode :: AllConfigs , code:: EinCode , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
101
106
LT = OMEinsum. labeltype (code)
102
- bounding_contract (NestedEinsum (NestedEinsum {DynamicEinCode{LT}} .(1 : length (xsa)), code), xsa, ymask, xsb; size_info= size_info)
107
+ bounding_contract (mode, NestedEinsum (NestedEinsum {DynamicEinCode{LT}} .(1 : length (xsa)), code), xsa, ymask, xsb; size_info= size_info)
103
108
end
104
- function bounding_contract (code:: NestedEinsum , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
109
+ function bounding_contract (mode :: AllConfigs , code:: NestedEinsum , @nospecialize (xsa), ymask, @nospecialize (xsb); size_info= nothing )
105
110
size_dict = size_info=== nothing ? Dict {OMEinsum.labeltype(code.eins),Int} () : copy (size_info)
106
111
OMEinsum. get_size_dict! (code, xsa, size_dict)
107
112
# compute intermediate tensors
108
113
@debug " caching einsum..."
109
114
c = cached_einsum (code, xsa, size_dict)
110
115
# compute masks from cached tensors
111
116
@debug " generating masked tree..."
112
- mt = generate_masktree (code, c, ymask, size_dict, :all )
117
+ mt = generate_masktree (mode, code, c, ymask, size_dict)
113
118
# compute results with masks
114
119
masked_einsum (code, xsb, mt, size_dict)
115
120
end
@@ -129,7 +134,7 @@ function solution_ad(code::NestedEinsum, @nospecialize(xsa), ymask; size_info=no
129
134
n = asscalar (c. content)
130
135
# compute masks from cached tensors
131
136
@debug " generating masked tree..."
132
- mt = generate_masktree (code, c, ymask, size_dict, :single )
137
+ mt = generate_masktree (SingleConfig (), code, c, ymask, size_dict)
133
138
n, read_config! (code, mt, Dict ())
134
139
end
135
140
0 commit comments