2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -88,7 +88,7 @@ Do not modify directly.*
|Conv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|22+|**T** = tensor(float)|
|||[11, 21]|**T** = tensor(float)|
|||[1, 10]|**T** = tensor(float)|
|ConvInteger|*in* x:**T1**<br> *in* w:**T2**<br> *in* x_zero_point:**T1**<br> *in* w_zero_point:**T2**<br> *out* y:**T3**|10+|**T1** = tensor(uint8)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int32)|
|ConvInteger|*in* x:**T1**<br> *in* w:**T2**<br> *in* x_zero_point:**T1**<br> *in* w_zero_point:**T2**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int32)|
|ConvTranspose|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|22+|**T** = tensor(float)|
|||[11, 21]|**T** = tensor(float)|
|||[1, 10]|**T** = tensor(float)|
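For context on the numerics behind the widened T1/T2 constraints above: ConvInteger subtracts the zero points from x and w and accumulates the products in int32, so once the inputs are widened the same arithmetic covers int8 and uint8 alike. A minimal 1-D reference sketch (illustrative only, not the ORT kernel; names are made up):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Reference semantics of ConvInteger: y = sum((x - x_zero_point) * (w - w_zero_point)),
// accumulated in int32. Widening every element to int32 first is what makes the
// same loop valid for both int8 and uint8 inputs.
template <typename TX, typename TW>
std::vector<int32_t> ConvInteger1DRef(const std::vector<TX>& x, const std::vector<TW>& w,
                                      TX x_zero_point, TW w_zero_point) {
  const size_t out_len = x.size() - w.size() + 1;  // no padding, stride 1
  std::vector<int32_t> y(out_len, 0);
  for (size_t i = 0; i < out_len; ++i) {
    for (size_t k = 0; k < w.size(); ++k) {
      y[i] += (static_cast<int32_t>(x[i + k]) - static_cast<int32_t>(x_zero_point)) *
              (static_cast<int32_t>(w[k]) - static_cast<int32_t>(w_zero_point));
    }
  }
  return y;
}

int main() {
  // Signed activations and weights, as now accepted by the CPU kernel.
  std::vector<int8_t> x = {-3, 0, 5, 7, -1};
  std::vector<int8_t> w = {1, -2};
  auto y = ConvInteger1DRef<int8_t, int8_t>(x, w, /*x_zero_point=*/1, /*w_zero_point=*/0);
  for (int32_t v : y) std::printf("%d ", static_cast<int>(v));  // -2 -9 -8 10
  std::printf("\n");
  return 0;
}
```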
115 changes: 78 additions & 37 deletions onnxruntime/core/providers/cpu/quantization/conv_integer.cc
@@ -28,8 +28,10 @@ ONNX_OPERATOR_KERNEL_EX(
10,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>())
.TypeConstraint("T1", {DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()})
.TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()})
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
ConvInteger);

@@ -43,12 +45,12 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
if (num_inputs >= 3 && input_defs[2]->Exists()) {
const auto* X_Zero_Point = context->Input<Tensor>(2);
ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1.");
input_offset = *(X_Zero_Point->Data<uint8_t>());
input_offset = *static_cast<const uint8_t*>(X_Zero_Point->DataRaw());
}
if (num_inputs >= 4 && input_defs[3]->Exists()) {
const auto* W_Zero_Point = context->Input<Tensor>(3);
ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now.");
filter_offset = *(W_Zero_Point->Data<uint8_t>());
filter_offset = *static_cast<const uint8_t*>(W_Zero_Point->DataRaw());
}

const int64_t N = X->Shape()[0];
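A note on the zero-point handling above: the element type of the zero-point tensor may now be int8_t or uint8_t, so the kernel reads the raw byte through DataRaw() as uint8_t and the signed path later casts it back to int8_t. That round trip is value-preserving for a single byte on the two's-complement targets ORT runs on; a standalone sketch of the idea (not ORT code):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // A signed zero point as it would sit in the tensor's buffer.
  int8_t zp_signed = -3;

  // Read the same byte through a uint8_t pointer, analogous to DataRaw().
  uint8_t raw = *reinterpret_cast<const uint8_t*>(&zp_signed);  // 253 (two's complement)

  // The signed branch casts the byte back before handing it to Im2col.
  int8_t recovered = static_cast<int8_t>(raw);

  std::printf("raw=%u recovered=%d\n", static_cast<unsigned>(raw), recovered);  // raw=253 recovered=-3
  return 0;
}
```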
@@ -110,58 +112,97 @@ Status ConvInteger::Compute(OpKernelContext* context) const {

concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();

const auto* Xdata = X->Data<uint8_t>();
const auto* Wdata = W->Data<uint8_t>();
const auto* Xdata = static_cast<const uint8_t*>(X->DataRaw());
const auto* Wdata = static_cast<const uint8_t*>(W->DataRaw());
bool X_is_signed = X->IsDataType<int8_t>();
auto* Ydata = Y->MutableData<int32_t>();

for (int image_id = 0; image_id < N; ++image_id) {
for (int group_id = 0; group_id < conv_attrs_.group; ++group_id) {
if (col_buffer_data != nullptr) {
if (kernel_rank == 2) {
math::Im2col<uint8_t, StorageOrder::NCHW>()(
Xdata,
C / conv_attrs_.group,
input_shape[0],
input_shape[1],
kernel_shape[0],
kernel_shape[1],
dilations[0],
dilations[1],
pads[0],
pads[1],
pads[2],
pads[3],
strides[0],
strides[1],
col_buffer_data,
input_offset);
if (X_is_signed) {
math::Im2col<int8_t, StorageOrder::NCHW>()(
reinterpret_cast<const int8_t*>(Xdata),
C / conv_attrs_.group,
input_shape[0],
input_shape[1],
kernel_shape[0],
kernel_shape[1],
dilations[0],
dilations[1],
pads[0],
pads[1],
pads[2],
pads[3],
strides[0],
strides[1],
reinterpret_cast<int8_t*>(col_buffer_data),
static_cast<int8_t>(input_offset));
} else {
math::Im2col<uint8_t, StorageOrder::NCHW>()(
Xdata,
C / conv_attrs_.group,
input_shape[0],
input_shape[1],
kernel_shape[0],
kernel_shape[1],
dilations[0],
dilations[1],
pads[0],
pads[1],
pads[2],
pads[3],
strides[0],
strides[1],
col_buffer_data,
input_offset);
}
} else {
math::Im2col<uint8_t, StorageOrder::NCHW>()(
Xdata,
input_shape.GetDims().data(),
output_shape.GetDims().data(),
kernel_dim,
kernel_shape.data(),
strides.data(),
dilations.data(),
pads.data(),
static_cast<int>(kernel_rank),
col_buffer_data,
false,
input_offset);
if (X_is_signed) {
math::Im2col<int8_t, StorageOrder::NCHW>()(
reinterpret_cast<const int8_t*>(Xdata),
input_shape.GetDims().data(),
output_shape.GetDims().data(),
kernel_dim,
kernel_shape.data(),
strides.data(),
dilations.data(),
pads.data(),
static_cast<int>(kernel_rank),
reinterpret_cast<int8_t*>(col_buffer_data),
false,
static_cast<int8_t>(input_offset));
} else {
math::Im2col<uint8_t, StorageOrder::NCHW>()(
Xdata,
input_shape.GetDims().data(),
output_shape.GetDims().data(),
kernel_dim,
kernel_shape.data(),
strides.data(),
dilations.data(),
pads.data(),
static_cast<int>(kernel_rank),
col_buffer_data,
false,
input_offset);
}
}
}

MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
gemm_shape.M = static_cast<size_t>(M / conv_attrs_.group);
gemm_shape.N = static_cast<size_t>(output_image_size);
gemm_shape.K = static_cast<size_t>(kernel_dim);
gemm_shape.AIsSigned = W->IsDataType<int8_t>();
gemm_shape.BIsSigned = X_is_signed;

MLAS_GEMM_QUANT_DATA_PARAMS gemm_params;
gemm_params.A = Wdata + group_id * W_offset;
gemm_params.lda = static_cast<size_t>(kernel_dim);
gemm_params.ZeroPointA = filter_offset;
gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data,
gemm_params.B = (col_buffer_data == nullptr) ? Xdata : col_buffer_data;
gemm_params.ldb = static_cast<size_t>(output_image_size);
gemm_params.ZeroPointB = &input_offset;
gemm_params.C = Ydata;
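For reference, in the MLAS GEMM configured above, A is the weight matrix and B is the im2col'd input, which is why AIsSigned tracks W while BIsSigned tracks X. A naive stand-in for that quantized GEMM (illustrative only; MLAS uses packed, vectorized kernels, and the names here are made up):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Naive equivalent of the quantized GEMM parameters above:
// C[m][n] = sum_k (A[m][k] - zp_a) * (B[k][n] - zp_b), accumulated in int32.
// The signed flags only change how each raw byte is widened to int32.
std::vector<int32_t> QuantGemmRef(const uint8_t* a, const uint8_t* b,
                                  size_t M, size_t N, size_t K,
                                  int32_t zp_a, int32_t zp_b,
                                  bool a_is_signed, bool b_is_signed) {
  auto widen = [](uint8_t v, bool is_signed) -> int32_t {
    return is_signed ? static_cast<int32_t>(static_cast<int8_t>(v))
                     : static_cast<int32_t>(v);
  };
  std::vector<int32_t> c(M * N, 0);
  for (size_t m = 0; m < M; ++m)
    for (size_t n = 0; n < N; ++n)
      for (size_t k = 0; k < K; ++k)
        c[m * N + n] += (widen(a[m * K + k], a_is_signed) - zp_a) *
                        (widen(b[k * N + n], b_is_signed) - zp_b);
  return c;
}
```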
1 change: 1 addition & 0 deletions onnxruntime/core/util/math_cpu.cc
@@ -527,6 +527,7 @@ void Im2col<T, StorageOrder::NCHW>::operator()(

template struct Im2col<float, StorageOrder::NCHW>;
template struct Im2col<uint8_t, StorageOrder::NCHW>;
template struct Im2col<int8_t, StorageOrder::NCHW>;

template <typename T>
void Im2col<T, StorageOrder::NHWC>::operator()(
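The single added line here is an explicit instantiation: Im2col's operator() is defined in math_cpu.cc, so every element type used from other translation units (now including int8_t from conv_integer.cc) needs an instantiation in this file for the program to link. A schematic sketch of the pattern, with made-up names and the two files shown inline:

```cpp
#include <cstdint>

// widget.h — template declared; the member is defined out of line in the .cc file.
template <typename T>
struct Widget {
  void Run(const T* data, int n);
};

// widget.cc — definition plus explicit instantiations for every T callers use.
template <typename T>
void Widget<T>::Run(const T* /*data*/, int /*n*/) { /* ... */ }

template struct Widget<float>;
template struct Widget<uint8_t>;
template struct Widget<int8_t>;  // without this line, code using Widget<int8_t> would fail to link
```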