Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Optimize 'take' operator for CPU (#20745)
Browse files Browse the repository at this point in the history
* Improve performance of take operator

* remove comment

* Fix build

* fix sanity

* Add comment

* review

* Update src/operator/tensor/indexing_op.h

Co-authored-by: bartekkuncer <[email protected]>

Co-authored-by: Sheng Zha <[email protected]>
Co-authored-by: bartekkuncer <[email protected]>
  • Loading branch information
3 people authored Jan 19, 2022
1 parent 7d84b59 commit 69e6c04
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 32 deletions.
99 changes: 69 additions & 30 deletions src/operator/tensor/indexing_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,51 @@ struct TakeZeroAxisCPU {
}
};

/*!
 * \brief CPU-oriented take kernel for the case where the take axis is not zero.
 *
 * Each Map call handles one value of `i` and loops over every entry of
 * `indices`, copying `axis_dim_stride` contiguous elements per entry with a
 * single memcpy.
 *
 * \tparam clip if true, out-of-range indices are clamped to [0, axis_dim - 1];
 *              otherwise they are wrapped modulo axis_dim (negative values
 *              wrap around from the end).
 */
template <bool clip = true>
struct TakeNonzeroAxisCPU {
  /*!
   * \brief Map function for take operator
   * \param i global thread id; used as the index of the slice in front of the axis
   * \param out_data ptr to output buffer
   * \param in_data ptr to input buffer
   * \param indices ptr to indices buffer
   * \param outer_dim_stride stride of dimension before axis
   * \param axis_dim_stride stride of axis dimension; also the number of
   *        elements copied for every index
   * \param idx_size size of the indices tensor
   * \param axis_dim dim size of the axis dimension
   * \param axis axis id (not read by this kernel; kept so the signature stays
   *        compatible with existing launch sites)
   */
  template <typename DType, typename IType>
  MSHADOW_XINLINE static void Map(index_t i,
                                  DType* out_data,
                                  const DType* in_data,
                                  const IType* indices,
                                  const index_t outer_dim_stride,
                                  const index_t axis_dim_stride,
                                  const int idx_size,
                                  const int axis_dim,
                                  const int axis) {
    // Do the clip/wrap arithmetic in index_t. Narrowing the raw index to int
    // first would truncate values that do not fit in int (possible when
    // index_t is wider than int) and make clip/wrap pick an unrelated element.
    const index_t axis_len = static_cast<index_t>(axis_dim);
    for (index_t j = 0; j < static_cast<index_t>(idx_size); ++j) {
      index_t index = static_cast<index_t>(indices[j]);
      if (clip) {
        // Clamp into [0, axis_len - 1].
        index = std::max<index_t>(index, 0);
        index = std::min<index_t>(axis_len - 1, index);
      } else {
        // Wrap; '%' keeps the sign of the dividend, so shift negatives up.
        index %= axis_len;
        index += (index < 0) ? axis_len : 0;
      }
      // Source: slice i, position `index` along the axis.
      size_t in_offset = i * outer_dim_stride + index * axis_dim_stride;
      // Destination: slice i, j-th taken entry, written back to back.
      size_t out_offset = (i * idx_size + j) * axis_dim_stride;
      // -Wclass-memaccess (GCC >= 8) is silenced around the raw memcpy below,
      // since DType may be a class type.
#pragma GCC diagnostic push
#if __GNUC__ >= 8
#pragma GCC diagnostic ignored "-Wclass-memaccess"
#endif
      std::memcpy(out_data + out_offset, in_data + in_offset, axis_dim_stride * sizeof(DType));
#pragma GCC diagnostic pop
    }
  }
};

/*
* \brief returns true if all indices are between [min, max]
* \param data_ptr the indices to check
Expand Down Expand Up @@ -323,6 +368,7 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;

if (req[take_::kOut] == kNullOp)
return;
const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
Expand Down Expand Up @@ -375,39 +421,32 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
in_strides[i] = stride;
}
mshadow::Shape<10> out_strides;
stride = 1;
for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
out_strides[i] = stride;
int outer_dimensions = 1;
for (int i = 0; i < actual_axis; i++) {
outer_dimensions *= oshape[i];
}
if (param.mode == take_::kClip) {
Kernel<TakeNonzeroAxis<true>, cpu>::Launch(s,
oshape.Size(),
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
inputs[take_::kIdx].dptr<IType>(),
out_strides[actual_axis - 1],
in_strides[actual_axis - 1],
in_strides[actual_axis],
arrshape.ndim(),
oshape.ndim(),
idxshape.ndim(),
arrshape[actual_axis],
actual_axis);
Kernel<TakeNonzeroAxisCPU<true>, cpu>::Launch(s,
outer_dimensions,
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
inputs[take_::kIdx].dptr<IType>(),
in_strides[actual_axis - 1],
in_strides[actual_axis],
idxshape.Size(),
arrshape[actual_axis],
actual_axis);
} else {
Kernel<TakeNonzeroAxis<false>, cpu>::Launch(s,
oshape.Size(),
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
inputs[take_::kIdx].dptr<IType>(),
out_strides[actual_axis - 1],
in_strides[actual_axis - 1],
in_strides[actual_axis],
arrshape.ndim(),
oshape.ndim(),
idxshape.ndim(),
arrshape[actual_axis],
actual_axis);
Kernel<TakeNonzeroAxisCPU<false>, cpu>::Launch(s,
outer_dimensions,
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
inputs[take_::kIdx].dptr<IType>(),
in_strides[actual_axis - 1],
in_strides[actual_axis],
idxshape.Size(),
arrshape[actual_axis],
actual_axis);
}
}
});
Expand Down
5 changes: 3 additions & 2 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,9 @@ inline bool EmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
return dispatched;
}

/*! \brief name the struct TakeNonzeroAxis for general take when
* axis is not zero, use TakeZeroAxisGPU or TakeZeroAxisCPU for axis zero
/*! \brief TakeNonzeroAxis is designated for general take when
* axis is not zero (for the CPU-optimized version use TakeNonzeroAxisCPU, and
* for axis zero use TakeZeroAxisGPU or TakeZeroAxisCPU)
*/
template <bool clip = true>
struct TakeNonzeroAxis {
Expand Down

0 comments on commit 69e6c04

Please sign in to comment.