Skip to content

Commit ffcb1c7

Browse files
authored
Optimizations (gorgonia#121)
* fix a bug in ByIdx * Added comments and documentation to softmax * Renamed the variables in SelectByIndicesB and SelectByIndices for better clarity. * Added example of ByIndices * Removed the allocation in `SoftMax` and `SoftMaxB` as suggested by @dcu * Switch to using `getFloat64s` and `getFloat32s` (new utility func) to reduce allocations ``` Results vs prev: benchmark old ns/op new ns/op delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 2237 2057 -8.05% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 2138 1920 -10.20% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 2112 1798 -14.87% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 2123 1844 -13.14% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 2236 1937 -13.37% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 2305 2040 -11.50% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 2167 1931 -10.89% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 2261 1884 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 2119 2035 -3.96% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 2143 1846 -13.86% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 2212 1821 -17.68% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 2164 1930 -10.81% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 36898948 36137745 -2.06% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 35541861 35019509 -1.47% benchmark old allocs new allocs delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 16 12 -25.00% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 16 12 -25.00% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 16 12 -25.00% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 16 12 -25.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 16 12 -25.00% 
BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 17 13 -23.53% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 17 13 -23.53% benchmark old bytes new bytes delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 664 568 -14.46% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 616 520 -15.58% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 664 568 -14.46% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 616 520 -15.58% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 696 600 -13.79% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 648 552 -14.81% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 696 600 -13.79% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 648 552 -14.81% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 696 600 -13.79% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 648 552 -14.81% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 696 600 -13.79% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 648 552 -14.81% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 19392926 19392912 -0.00% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 9701448 9701351 -0.00% ``` * `SoftMax` and `SoftMaxB` optimization: Removed unnecessary calls to .Clone() which reduces the number of allocs. 
Results: ``` benchmark old ns/op new ns/op delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 2057 1619 -21.29% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 1920 1563 -18.59% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 1798 1508 -16.13% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 1844 1575 -14.59% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 1937 1836 -5.21% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 2040 1672 -18.04% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 1931 1704 -11.76% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 1884 1542 -18.15% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 2035 1558 -23.44% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 1846 1626 -11.92% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 1821 1552 -14.77% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 1930 1499 -22.33% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 36137745 36795574 +1.82% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 35019509 34759423 -0.74% benchmark old allocs new allocs delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 12 10 -16.67% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 12 10 -16.67% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 12 10 -16.67% BenchmarkSoftmax/(3,4)_Float32_axis_1-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 12 10 -16.67% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 12 10 -16.67% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 13 11 -15.38% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 13 11 -15.38% benchmark old bytes new bytes delta BenchmarkSoftmax/(3,4)_Float64_axis_0-20 568 528 -7.04% BenchmarkSoftmax/(3,4)_Float32_axis_0-20 520 480 -7.69% BenchmarkSoftmax/(3,4)_Float64_axis_1-20 568 528 -7.04% 
BenchmarkSoftmax/(3,4)_Float32_axis_1-20 520 480 -7.69% BenchmarkSoftmax/(2,3,2)_Float64_axis_0-20 600 552 -8.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_0-20 552 504 -8.70% BenchmarkSoftmax/(2,3,2)_Float64_axis_1-20 600 552 -8.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_1-20 552 504 -8.70% BenchmarkSoftmax/(2,3,2)_Float64_axis_2-20 600 552 -8.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_2-20 552 504 -8.70% BenchmarkSoftmax/(2,3,2)_Float64_axis_-1-20 600 552 -8.00% BenchmarkSoftmax/(2,3,2)_Float32_axis_-1-20 552 504 -8.70% BenchmarkSoftmax/(641,19,199)_Float64_axis_-1-20 19392912 19392892 -0.00% BenchmarkSoftmax/(641,_19,_199)_Float32_axis_-1-20 9701351 9701312 -0.00% ``` * Parallelized the softmax code
1 parent 3039f42 commit ffcb1c7

8 files changed

+487
-285
lines changed

defaultengine_selbyidx.go

+33-26
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,27 @@ import (
77
"reflect"
88
)
99

10-
func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
11-
if !b.Shape().IsVectorLike() {
12-
return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", b.Shape())
10+
// SelectByIndices selects the values given in the `indices` tensor.
11+
//
12+
// Currently SelectByIndices only supports Dense tensors that do not require the use of iterators.
13+
// Please make a pull request to support tensors that require the use of an iterator to traverse data.
14+
func (e StdEng) SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
15+
if !indices.Shape().IsVectorLike() {
16+
return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", indices.Shape())
1317
}
14-
if b.Dtype() != Int {
15-
return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", b.Dtype())
18+
if indices.Dtype() != Int {
19+
return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", indices.Dtype())
1620
}
1721

1822
// if a is scalar-equivalent, then use Slice
1923
if a.Shape().IsScalarEquiv() {
2024
slices := make([]Slice, a.Shape().Dims())
21-
slices[axis] = ss(b.Data().([]int)[0])
25+
slices[axis] = ss(getInts(indices)[0])
2226
return a.Slice(slices...)
2327
}
2428

2529
expectedShape := a.Shape().Clone()
26-
expectedShape[axis] = b.Shape().TotalSize()
30+
expectedShape[axis] = indices.Shape().TotalSize()
2731

2832
var reuse DenseTensor
2933
var safe, toReuse, _ bool
@@ -36,9 +40,9 @@ func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal
3640
}
3741

3842
if !safe {
39-
if a.Shape()[axis] != b.Shape().TotalSize() {
43+
if a.Shape()[axis] != indices.Shape().TotalSize() {
4044
expected := a.Shape().Clone()
41-
expected[axis] = b.Shape().TotalSize()
45+
expected[axis] = indices.Shape().TotalSize()
4246
return nil, errors.Errorf("Expected a safe resuse to have the same shape as the expected shape of the result: %v. The input a has %v ", expected, a.Shape())
4347
}
4448

@@ -49,7 +53,7 @@ func (e StdEng) SelectByIndices(a, b Tensor, axis int, opts ...FuncOpt) (retVal
4953
var dataA, dataB, dataReuse *storage.Header
5054
var ait, bit, iit Iterator
5155
var useIter bool
52-
if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, b, reuse); err != nil {
56+
if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, indices, reuse); err != nil {
5357
return nil, errors.Wrapf(err, "StdEng.Add")
5458
}
5559

@@ -130,39 +134,42 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da
130134
}
131135
}
132136

133-
// SelectByIndicesB is the backwards function of SelectByIndices.
134-
func (e StdEng) SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
137+
// SelectByIndicesB computes the gradient of the result of `SelectByIndices`.
138+
//
139+
// Currently SelectByIndicesB only supports Dense tensors that do not require the use of iterators.
140+
// Please make a pull request to support tensors that require the use of an iterator to traverse data.
141+
func (e StdEng) SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
135142
if !indices.Shape().IsVectorLike() {
136-
return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", b.Shape())
143+
return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", outGrad.Shape())
137144
}
138145
if indices.Dtype() != Int {
139-
return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", b.Dtype())
146+
return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", outGrad.Dtype())
140147
}
141148

142149
// if the input tensor is scalar-equivalent, then use Slice
143-
if a.Shape().IsScalarEquiv() {
144-
slices := make([]Slice, a.Shape().Dims())
145-
slices[axis] = ss(b.Data().([]int)[0])
146-
return a.Slice(slices...)
150+
if input.Shape().IsScalarEquiv() {
151+
slices := make([]Slice, input.Shape().Dims())
152+
slices[axis] = ss(outGrad.Data().([]int)[0])
153+
return input.Slice(slices...)
147154
}
148155

149-
expectedShape := a.Shape().Clone()
156+
expectedShape := input.Shape().Clone()
150157

151158
var reuse DenseTensor
152159
var _, toReuse, _ bool
153-
if reuse, _, toReuse, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil {
160+
if reuse, _, toReuse, _, _, err = handleFuncOpts(input.Shape(), input.Dtype(), input.DataOrder(), true, opts...); err != nil {
154161
return nil, errors.Wrap(err, "Unable to handle funcOpts")
155162
}
156163
if !toReuse && reuse == nil {
157164
// create reuse
158-
reuse = New(WithShape(expectedShape...), Of(a.Dtype()))
165+
reuse = New(WithShape(expectedShape...), Of(input.Dtype()))
159166
}
160167

161-
typ := a.Dtype().Type
168+
typ := input.Dtype().Type
162169
var _, dataB, dataReuse *storage.Header
163170
var _, bit, iit Iterator
164171
var useIter bool
165-
if _, dataB, dataReuse, _, bit, iit, useIter, _, err = prepDataVV(a, b, reuse); err != nil {
172+
if _, dataB, dataReuse, _, bit, iit, useIter, _, err = prepDataVV(input, outGrad, reuse); err != nil {
166173
return nil, errors.Wrapf(err, "StdEng.SelectByIndicesB")
167174
}
168175

@@ -172,7 +179,7 @@ func (e StdEng) SelectByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt
172179
return
173180
}
174181

175-
e.selectByIndicesB(axis, indices.Data().([]int), typ, dataB, dataReuse, b.(*Dense).AP, reuse.(*Dense).AP)
182+
e.selectByIndicesB(axis, getInts(indices), typ, dataB, dataReuse, outGrad.(*Dense).AP, reuse.(*Dense).AP)
176183

177184
return reuse, nil
178185
}
@@ -228,8 +235,8 @@ func (e StdEng) selectByIndicesB(axis int, indices []int, typ reflect.Type, data
228235
for i, idx := range indices {
229236
dstCoord[axis] = idx
230237
srcCoord[axis] = i
231-
dstStart, _ := Ltoi(apB.shape, apB.strides, dstCoord...)
232-
start, _ := Ltoi(apRet.shape, apRet.strides, srcCoord...)
238+
dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...)
239+
start, _ := Ltoi(apB.shape, apB.strides, srcCoord...)
233240

234241
for o := 0; o < outer; o++ {
235242
dstEnd := dstStart + axStride

0 commit comments

Comments
 (0)