298 changes: 160 additions & 138 deletions onnxruntime/core/providers/cpu/quantization/conv_integer.cc
@@ -20,161 +20,183 @@ class ConvInteger : public OpKernel {
Status Compute(OpKernelContext* context) const override;

ConvAttributes conv_attrs_;
};

ONNX_OPERATOR_KERNEL_EX(
ConvInteger,
kOnnxDomain,
10,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>())
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
ConvInteger);

Status ConvInteger::Compute(OpKernelContext* context) const {
const auto input_defs = Node().InputDefs();
size_t num_inputs = input_defs.size();
const auto* X = context->Input<Tensor>(0);
const auto* W = context->Input<Tensor>(1);
uint8_t input_offset = 0;
uint8_t filter_offset = 0;
if (num_inputs >= 3 && input_defs[2]->Exists()) {
const auto* X_Zero_Point = context->Input<Tensor>(2);
ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1.");
input_offset = *(X_Zero_Point->Data<uint8_t>());
}
if (num_inputs >= 4 && input_defs[3]->Exists()) {
const auto* W_Zero_Point = context->Input<Tensor>(3);
ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now.");
filter_offset = *(W_Zero_Point->Data<uint8_t>());
}
private:
template <typename XT, typename WT>
Status ComputeInner(OpKernelContext* context) const {
const auto input_defs = Node().InputDefs();
size_t num_inputs = input_defs.size();
const auto* X = context->Input<Tensor>(0);
const auto* W = context->Input<Tensor>(1);
uint8_t input_offset = 0;
uint8_t filter_offset = 0;
if (num_inputs >= 3 && input_defs[2]->Exists()) {
const auto* X_Zero_Point = context->Input<Tensor>(2);
ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1.");
input_offset = *static_cast<const uint8_t*>(X_Zero_Point->DataRaw());
}
if (num_inputs >= 4 && input_defs[3]->Exists()) {
const auto* W_Zero_Point = context->Input<Tensor>(3);
ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now.");
filter_offset = *static_cast<const uint8_t*>(W_Zero_Point->DataRaw());
}

const int64_t N = X->Shape()[0];
const int64_t C = X->Shape()[1];
const int64_t M = W->Shape()[0];
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));

TensorShapeVector kernel_shape;
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape));

ConvPadVector pads(conv_attrs_.pads);
if (pads.empty()) {
pads.resize(kernel_shape.size() * 2, 0);
}
TensorShapeVector dilations(conv_attrs_.dilations);
if (dilations.empty()) {
dilations.resize(kernel_shape.size(), 1);
}
TensorShapeVector strides(conv_attrs_.strides);
if (strides.empty()) {
strides.resize(kernel_shape.size(), 1);
}

TensorShapeVector Y_dims({N, M});
TensorShape input_shape = X->Shape().Slice(2);
ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims));
Tensor* Y = context->Output(0, TensorShape(Y_dims));
TensorShape output_shape = Y->Shape().Slice(2);

// Bail out early if one of the dimensions is zero.
if (Y->Shape().Size() == 0) {
return Status::OK();
}

const int64_t input_image_size = input_shape.Size();
const int64_t output_image_size = output_shape.Size();
const int64_t kernel_size = TensorShape(kernel_shape).Size();
const int64_t X_offset = C / conv_attrs_.group * input_image_size;
const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / conv_attrs_.group;
const int64_t W_offset = W->Shape().Size() / conv_attrs_.group;
const int64_t kernel_dim = C / conv_attrs_.group * kernel_size;
const int64_t col_buffer_size = kernel_dim * output_image_size;

const size_t kernel_rank = kernel_shape.size();

BufferUniquePtr col_buffer;

// Pointwise convolutions can use the original input tensor in place,
// otherwise a temporary buffer is required for the im2col transform.
if (kernel_size != 1 || !conv_attrs_.HasStridesOneAndNoPadding()) {
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));

auto* col_data = alloc->Alloc(SafeInt<size_t>(sizeof(uint8_t)) * col_buffer_size);
col_buffer = BufferUniquePtr(col_data, BufferDeleter(std::move(alloc)));
}

auto* col_buffer_data = static_cast<uint8_t*>(col_buffer.get());

concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();

const auto* Xdata = X->Data<uint8_t>();
const auto* Wdata = W->Data<uint8_t>();
auto* Ydata = Y->MutableData<int32_t>();

for (int image_id = 0; image_id < N; ++image_id) {
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
if (col_buffer_data != nullptr) {
if (kernel_rank == 2) {
math::Im2col<uint8_t, StorageOrder::NCHW>()(
Xdata,
C / conv_attrs_.group,
input_shape[0],
input_shape[1],
kernel_shape[0],
kernel_shape[1],
dilations[0],
dilations[1],
pads[0],
pads[1],
pads[2],
pads[3],
strides[0],
strides[1],
col_buffer_data,
input_offset);
} else {
math::Im2col<uint8_t, StorageOrder::NCHW>()(
Xdata,
input_shape.GetDims().data(),
output_shape.GetDims().data(),
kernel_dim,
kernel_shape.data(),
strides.data(),
dilations.data(),
pads.data(),
static_cast<int>(kernel_rank),
col_buffer_data,
false,
input_offset);
auto* col_buffer_data = static_cast<uint8_t*>(col_buffer.get());

concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();

const auto* Xdata = static_cast<const uint8_t*>(X->DataRaw());
const auto* Wdata = static_cast<const uint8_t*>(W->DataRaw());
auto* Ydata = Y->MutableData<int32_t>();

for (int image_id = 0; image_id < N; ++image_id) {
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
if (col_buffer_data != nullptr) {
if (kernel_rank == 2) {
math::Im2col<XT, StorageOrder::NCHW>()(
reinterpret_cast<const XT*>(Xdata),
C / conv_attrs_.group,
input_shape[0],
input_shape[1],
kernel_shape[0],
kernel_shape[1],
dilations[0],
dilations[1],
pads[0],
pads[1],
pads[2],
pads[3],
strides[0],
strides[1],
reinterpret_cast<XT*>(col_buffer_data),
static_cast<XT>(input_offset));
} else {
math::Im2col<XT, StorageOrder::NCHW>()(
reinterpret_cast<const XT*>(Xdata),
input_shape.GetDims().data(),
output_shape.GetDims().data(),
kernel_dim,
kernel_shape.data(),
strides.data(),
dilations.data(),
pads.data(),
static_cast<int>(kernel_rank),
reinterpret_cast<XT*>(col_buffer_data),
false,
static_cast<XT>(input_offset));
}
}
}

MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
gemm_shape.M = static_cast<size_t>(M / conv_attrs_.group);
gemm_shape.N = static_cast<size_t>(output_image_size);
gemm_shape.K = static_cast<size_t>(kernel_dim);

MLAS_GEMM_QUANT_DATA_PARAMS gemm_params;
gemm_params.A = Wdata + group_id * W_offset;
gemm_params.lda = static_cast<size_t>(kernel_dim);
gemm_params.ZeroPointA = filter_offset;
gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data,
gemm_params.ldb = static_cast<size_t>(output_image_size);
gemm_params.ZeroPointB = &input_offset;
gemm_params.C = Ydata;
gemm_params.ldc = static_cast<size_t>(output_image_size);

MlasGemm(gemm_shape, gemm_params, thread_pool);

Xdata += X_offset;
Ydata += Y_offset;
MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
gemm_shape.M = static_cast<size_t>(M / conv_attrs_.group);
gemm_shape.N = static_cast<size_t>(output_image_size);
gemm_shape.K = static_cast<size_t>(kernel_dim);
gemm_shape.AIsSigned = W->IsDataType<int8_t>();
gemm_shape.BIsSigned = X->IsDataType<int8_t>();

MLAS_GEMM_QUANT_DATA_PARAMS gemm_params;
gemm_params.A = Wdata + group_id * W_offset;
gemm_params.lda = static_cast<size_t>(kernel_dim);
gemm_params.ZeroPointA = filter_offset;
gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data,
gemm_params.ldb = static_cast<size_t>(output_image_size);
gemm_params.ZeroPointB = &input_offset;
gemm_params.C = Ydata;
gemm_params.ldc = static_cast<size_t>(output_image_size);

MlasGemm(gemm_shape, gemm_params, thread_pool);

Xdata = reinterpret_cast<const uint8_t*>(X_offset + reinterpret_cast<const XT*>(Xdata));
Ydata += Y_offset;
}
}

return Status::OK();
}
};
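For reference, the per-group MlasGemm call above uses the filter block as the GEMM A operand and the im2col'd input (or the raw input for pointwise convolutions) as B, so the int32 output follows the ConvInteger definition: Y = (W - filter_offset) * (col(X) - input_offset). The loop below is a minimal illustrative sketch of that product under an assumed row-major layout; the function name and signature are not part of the kernel.

// Illustrative reference for one group (requires <cstddef> and <cstdint>):
//   Y[m * N + p] = sum_k (W[m * K + k] - w_zp) * (col[k * N + p] - x_zp)
// XT/WT mirror the kernel's template parameters; everything else here is an assumption.
template <typename XT, typename WT>
void ReferenceQuantGemm(const WT* W, const XT* col, int32_t* Y,
                        size_t M, size_t N, size_t K,
                        int32_t w_zp, int32_t x_zp) {
  for (size_t m = 0; m < M; ++m) {
    for (size_t p = 0; p < N; ++p) {
      int32_t acc = 0;
      for (size_t k = 0; k < K; ++k) {
        acc += (static_cast<int32_t>(W[m * K + k]) - w_zp) *
               (static_cast<int32_t>(col[k * N + p]) - x_zp);
      }
      Y[m * N + p] = acc;
    }
  }
}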

ONNX_OPERATOR_KERNEL_EX(
ConvInteger,
kOnnxDomain,
10,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", {DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()})
.TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()})
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
ConvInteger);

Status ConvInteger::Compute(OpKernelContext* context) const {
const auto* X = context->Input<Tensor>(0);
const auto* W = context->Input<Tensor>(1);
if (X->IsDataType<int8_t>()) {
if (W->IsDataType<int8_t>())
return ComputeInner<int8_t, int8_t>(context);
else
return ComputeInner<int8_t, uint8_t>(context);
} else {
if (W->IsDataType<int8_t>())
return ComputeInner<uint8_t, int8_t>(context);
else
return ComputeInner<uint8_t, uint8_t>(context);
}
}

} // namespace onnxruntime
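A sketch of how the new signed paths might be exercised with ORT's OpTester test helper; the test name, shapes, and values below are illustrative assumptions, not part of this change. With a 1x1 kernel, stride 1, no padding, and x_zero_point = 1, each output element is (x - 1) * 2.

// Sketch only: assumes the gtest + OpTester setup used elsewhere in the CPU provider tests.
TEST(ConvIntegerTest, Int8Pointwise_Sketch) {
  OpTester test("ConvInteger", 10);
  test.AddInput<int8_t>("x", {1, 1, 3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
  test.AddInput<int8_t>("w", {1, 1, 1, 1}, {2});
  test.AddInput<int8_t>("x_zero_point", {}, {1});
  test.AddOutput<int32_t>("y", {1, 1, 3, 3}, {0, 2, 4, 6, 8, 10, 12, 14, 16});
  test.Run();
}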
1 change: 1 addition & 0 deletions onnxruntime/core/util/math_cpu.cc
@@ -527,6 +527,7 @@ void Im2col<T, StorageOrder::NCHW>::operator()(

template struct Im2col<float, StorageOrder::NCHW>;
template struct Im2col<uint8_t, StorageOrder::NCHW>;
template struct Im2col<int8_t, StorageOrder::NCHW>;

template <typename T>
void Im2col<T, StorageOrder::NHWC>::operator()(
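The one-line addition above follows the file's existing explicit-instantiation pattern: Im2col's NCHW operator() is defined in math_cpu.cc rather than in a header, so every element type used by callers (now including int8_t from ConvInteger) needs an explicit instantiation there to link. Below is a minimal sketch of that general pattern, with illustrative names that are not the actual ORT declarations.

// header (sketch): declaration only, definition lives in the .cc file.
template <typename T>
struct Scale {
  void operator()(const T* src, T* dst, size_t n, T factor) const;
};

// source file (sketch): definition plus explicit instantiations for each type callers use.
template <typename T>
void Scale<T>::operator()(const T* src, T* dst, size_t n, T factor) const {
  for (size_t i = 0; i < n; ++i) dst[i] = src[i] * factor;
}

template struct Scale<uint8_t>;
template struct Scale<int8_t>;  // without this line, int8_t callers fail at link time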