You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
returnnil, errors.Errorf("Expected indices to be a vector. Got %v instead", indices.Shape())
13
17
}
14
-
ifb.Dtype() !=Int {
15
-
returnnil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", b.Dtype())
18
+
ifindices.Dtype() !=Int {
19
+
returnnil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", indices.Dtype())
16
20
}
17
21
18
22
// if b is a scalar, then use Slice
19
23
ifa.Shape().IsScalarEquiv() {
20
24
slices:=make([]Slice, a.Shape().Dims())
21
-
slices[axis] =ss(b.Data().([]int)[0])
25
+
slices[axis] =ss(getInts(indices)[0])
22
26
returna.Slice(slices...)
23
27
}
24
28
25
29
expectedShape:=a.Shape().Clone()
26
-
expectedShape[axis] =b.Shape().TotalSize()
30
+
expectedShape[axis] =indices.Shape().TotalSize()
27
31
28
32
varreuseDenseTensor
29
33
varsafe, toReuse, _bool
@@ -36,9 +40,9 @@ func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal
36
40
}
37
41
38
42
if!safe {
39
-
ifa.Shape()[axis] !=b.Shape().TotalSize() {
43
+
ifa.Shape()[axis] !=indices.Shape().TotalSize() {
40
44
expected:=a.Shape().Clone()
41
-
expected[axis] =b.Shape().TotalSize()
45
+
expected[axis] =indices.Shape().TotalSize()
42
46
returnnil, errors.Errorf("Expected a safe resuse to have the same shape as the expected shape of the result: %v. The input a has %v ", expected, a.Shape())
43
47
}
44
48
@@ -49,7 +53,7 @@ func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal
0 commit comments