Sub-channel quantized type implementation #120172

Merged: 7 commits, Mar 23, 2025
41 changes: 41 additions & 0 deletions mlir/include/mlir-c/Dialect/Quant.h
@@ -172,6 +172,47 @@ mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type);
MLIR_CAPI_EXPORTED bool
mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type);

//===---------------------------------------------------------------------===//
// UniformQuantizedSubChannelType
//===---------------------------------------------------------------------===//

/// Returns `true` if the given type is a UniformQuantizedSubChannel.
MLIR_CAPI_EXPORTED bool
mlirTypeIsAUniformQuantizedSubChannelType(MlirType type);

/// Creates a UniformQuantizedSubChannelType with the given parameters.
///
/// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be
/// DenseElementsAttrs. `quantizedDimensions` and `blockSizes` each point to
/// `blockSizeInfoLength` elements, describing respectively the quantization
/// axes and the corresponding block sizes.
MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedSubChannelTypeGet(
unsigned flags, MlirType storageType, MlirType expressedType,
MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr,
intptr_t blockSizeInfoLength, int32_t *quantizedDimensions,
int64_t *blockSizes, int64_t storageTypeMin, int64_t storageTypeMax);

/// Returns the number of block sizes provided in the type.
MLIR_CAPI_EXPORTED intptr_t
mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type);

/// Returns the quantized dimension at the given position.
MLIR_CAPI_EXPORTED int32_t
mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type,
intptr_t pos);

/// Returns the block size at the given position.
MLIR_CAPI_EXPORTED int64_t
mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type, intptr_t pos);

/// Returns the scales of the quantized type.
MLIR_CAPI_EXPORTED MlirAttribute
mlirUniformQuantizedSubChannelTypeGetScales(MlirType type);

/// Returns the zero-points of the quantized type.
MLIR_CAPI_EXPORTED MlirAttribute
mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type);

//===---------------------------------------------------------------------===//
// CalibratedQuantizedType
//===---------------------------------------------------------------------===//
192 changes: 183 additions & 9 deletions mlir/include/mlir/Dialect/Quant/IR/QuantBase.td
@@ -40,13 +40,17 @@ def Quant_Dialect : Dialect {
encodes the necessary information for (lossy) round-trip conversion between
an expressed and a stored value.

The `quant.uniform` type has two variants: per-layer quantization and
per-channel (or per-axis) quantization. In per-layer quantization, the
quantization information affects an entire tensor uniformly. Conversely, in
per-channel quantization, the data type encodes the specific tensor axis
that serves as the channel and includes quantization information for each
individual channel within the tensor. Below are the specific syntactic and
semantic considerations for each modality.
The `quant.uniform` type has three variants: per-layer quantization,
    per-channel (or per-axis) quantization, and sub-channel (or blockwise)
quantization. In per-layer quantization, the quantization information
affects an entire tensor uniformly. Conversely, in per-channel
quantization, the data type encodes the specific tensor axis that serves
as the channel and includes quantization information for each individual
channel within the tensor. Sub-channel quantization is a generalization
of per-tensor and per-channel quantization, where the quantization
parameters are defined for blocks of elements along one or more
dimensions of the tensor. Below are the specific syntactic and semantic
considerations for each modality.


### Per-layer quantization
Expand Down Expand Up @@ -145,7 +149,7 @@ def Quant_Dialect : Dialect {
```
// A 2x3x4 tensor contains 8-bit signed integers representing 32-bit
// floats. Dimension 1 of the tensor acts as the channel dimension. Its
    // size 3 matches the number of provided scale values. Tensor elements at
// size 3 matches the number of provided scale values. Tensor elements at
// positions [*][0][*], [*][1][*], and [*][2][*] use scales 3.0, 4.0, and
// 5.0, respectively.
tensor<2x3x4x!quant.uniform<i8:f32:1, {3.0, 4.0, 5.0}>>
@@ -159,6 +163,72 @@ def Quant_Dialect : Dialect {
tensor<?x?x!quant.uniform<u16:f32:0, {2.0:10, 3.0:20}>>
```

### Sub-channel quantization

Sub-channel quantization, also known as blockwise quantization, provides
finer-grained control than per-tensor or per-channel quantization. It
divides a tensor into blocks of elements, each with its own quantization
parameters (scale and zero point). This is particularly useful when
different regions of a tensor exhibit distinct value ranges.

The `!quant.uniform` type represents sub-channel quantization with the
following syntax:

```
`!quant.uniform` `<`
storedType (`<` storageMin `:` storageMax `>`)? `:`
expressedType `:` blockSizeInfo
scaleZeroTensor `>`

    blockSizeInfo ::= `{` `}` | `{` axisBlock (`,` axisBlock)* `}`
axisBlock ::= axis `:` blockSize
scaleZeroTensor ::= scaleZeroDenseExp | scaleZeroList
scaleZeroDenseExp ::= `{` scaleZeroTensor (`,` scaleZeroTensor)* `}`
scaleZeroList ::= scaleZero (`,` scaleZero)*
    scaleZero ::= scale (`:` zeroPoint)?
```

    The `blockSize` field specifies the size of the blocks along dimension
    `axis` of the tensor. The `scale` and `zeroPoint` fields specify the
    quantization parameters for a particular block. Specifically, the tensor
    element at position [i0...iN] uses
    `scaleZeroTensor[i0/blockSize0...iN/blockSizeN].scale` and
    `scaleZeroTensor[i0/blockSize0...iN/blockSizeN].zeroPoint` as its scale
    and zeroPoint, respectively.
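
    The block-indexing rule above can be sketched in Python (`block_index` is
    a hypothetical helper for illustration, not part of the MLIR API):

    ```python
    def block_index(index, block_sizes):
        """Map a tensor element's position to the position of its
        quantization block (elementwise integer division)."""
        return tuple(i // b for i, b in zip(index, block_sizes))

    # A 3x4 tensor with block sizes [1, 2]: scales has shape [3, 2].
    scales = [[0.5, 1.0],
              [1.5, 2.0],
              [2.5, 3.0]]

    bi, bj = block_index((2, 3), block_sizes=(1, 2))  # element [2][3]
    print(scales[bi][bj])  # → 3.0
    ```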

Here are some examples:

```
// A 3x4 tensor of i8 values representing f32 values, quantized
// along axis-0 and axis-1 with block sizes 1 and 2,
// respectively. As a result, the shape of the scales (or zero-points) will
// be `[3,4]/[1,2] = [3,2]`, which essentially represents the number of
// blocks along each axis. Tensor elements at positions
// [0][0] and [0][1] use scale `s00` and zero point `z00`,
// [0][2] and [0][3] use scale `s01` and zero point `z01`,
// [1][0] and [1][1] use scale `s10` and zero point `z10`,
// [1][2] and [1][3] use scale `s11` and zero point `z11`,
// [2][0] and [2][1] use scale `s20` and zero point `z20`,
// [2][2] and [2][3] use scale `s21` and zero point `z21`,
tensor<3x4x!quant.uniform<i8:f32:{0:1, 1:2},
{{s00:z00, s01:z01}, {s10:z10,s11:z11}, {s20:z20,s21:z21}}>>

// A 2D dynamically sized tensor contains u16 values
// representing f32 values. Since the shape of the quantization
// parameters (i.e. scales and zero-points) is given as [2,2] and
// the blocks-sizes are given as [1,2], the shape of the tensor is expected
// to be [2,4] (= [2,2] * [1,2]) at runtime. Tensor elements at positions
// [0][0] and [0][1] use scale `s00` and zero point `z00`,
// [0][2] and [0][3] use scale `s01` and zero point `z01`,
// [1][0] and [1][1] use scale `s10` and zero point `z10`,
// [1][2] and [1][3] use scale `s11` and zero point `z11`,
tensor<?x?x!quant.uniform<u16:f32:{0:1, 1:2},
{{s00:z00, s01:z01}, {s10:z10,s11:z11}}>>
```
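
    The stored-to-expressed conversion for such a type can be sketched in
    plain Python, assuming the usual affine formula
    `expressed = scale * (stored - zeroPoint)` (an illustration, not the MLIR
    implementation):

    ```python
    def dequantize(stored, scales, zero_points, block_sizes):
        """Blockwise dequantization of a 2-D tensor: each element uses the
        scale and zero point of the block that contains it."""
        rows, cols = len(stored), len(stored[0])
        out = [[0.0] * cols for _ in range(rows)]
        for i in range(rows):
            for j in range(cols):
                bi, bj = i // block_sizes[0], j // block_sizes[1]
                out[i][j] = scales[bi][bj] * (stored[i][j] - zero_points[bi][bj])
        return out

    # A 2x4 tensor with block sizes [1, 2]: parameters have shape [2, 2].
    print(dequantize([[1, 2, 3, 4], [5, 6, 7, 8]],
                     scales=[[2.0, 3.0], [4.0, 5.0]],
                     zero_points=[[0, 1], [0, 0]],
                     block_sizes=(1, 2)))
    ```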

## Per-axis quantization integrity

@@ -170,7 +240,7 @@ def Quant_Dialect : Dialect {
respected in any context in which the `!quant.uniform` data type is used,
such as the header of a `func.func` op, or the input of an arithmetic
operation.

- A quantized type with per-channel quantization information must be the
element type of a tensor container type, and may not occur directly as
the data type of a scalar value.
@@ -209,6 +279,110 @@ def Quant_Dialect : Dialect {
// Correct. The quantized type now includes 3 scale values, matching the
// size of dimension 1 of the result tensor.
%result = quant.qcast %input : tensor<?x3xf32> to tensor<?x3x!quant.uniform<i8:f32:1, {2.0, 3.0, 4.0}>>

## Sub-channel quantization integrity

When type `!quant.uniform` contains sub-channel quantization information,
the following rules are enforced. For efficiency, these rules are actively
enforced by the verifiers of `quant` dialect ops, but they must be
respected in any context in which the `!quant.uniform` data type is used,
such as the header of a `func.func` op, or the input of an arithmetic
operation.

- A quantized type with sub-channel quantization information must be the
element type of a tensor container type, and may not occur directly as
the data type of a scalar value.

```
// Incorrect. Type !quant.uniform specifies sub-channel quantization for a
// scalar type.
%result = quant.qcast %input : f32 to !quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>

// Correct. Type `!quant.uniform` with sub-channel quantization is wrapped
// in a `tensor` type.
%result = quant.qcast %input : tensor<2x2xf32> to
tensor<2x2x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
```

- The tensor containing the sub-channel quantized type must be ranked.

```
// Incorrect. Type !quant.uniform specifies sub-channel quantization for a
// unranked tensor type.
%result = quant.qcast %input : tensor<*xf32> to
tensor<*x!quant.uniform<i8:f32:{0:1, 1:2}, {{1.0}, {2.0}}>>
```

    - The axis for which a block size is specified must be valid for a tensor
      of the given rank. Block sizes may be specified for only a subset of
      axes; the block size for any unspecified axis i defaults to the
      dimension size of that axis (shape(tensor)[i]).

```
// Incorrect. The block-size is specified for axis 2 which is greater than
// the rank of the tensor.
%result = quant.qcast %input : tensor<2x2xf32> to
tensor<2x2x!quant.uniform<i8:f32:{2:1, 1:2}, {{1.0}, {2.0}}>>

// Incorrect. The block-size is specified for a negative axis.
%result = quant.qcast %input : tensor<2x2xf32> to
tensor<2x2x!quant.uniform<i8:f32:{-1:1, 1:2}, {{1.0}, {2.0}}>>

    // Correct. The block size for axis 1 is omitted, so it defaults to 2,
    // the dimension size of the tensor along axis 1.
%result = quant.qcast %input : tensor<6x2xf32> to
tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {3.0}}>>

    // Correct. The block sizes for all axes are omitted, making the
    // sub-channel type equivalent to a per-tensor type.
%result = quant.qcast %input : tensor<6x2xf32> to
tensor<6x2x!quant.uniform<i8:f32:{}, {{1.0}}>>
```
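
    The defaulting behavior in this rule can be sketched as (a hypothetical
    helper for illustration, not part of the MLIR API; `specified` maps an
    axis to its block size):

    ```python
    def full_block_sizes(tensor_shape, specified):
        """Expand a sparse {axis: blockSize} spec to one block size per
        axis; an unspecified axis defaults to its full dimension size,
        i.e. a single block spanning that axis."""
        return [specified.get(axis, dim) for axis, dim in enumerate(tensor_shape)]

    print(full_block_sizes([6, 2], {0: 3}))  # → [3, 2]
    print(full_block_sizes([6, 2], {}))      # → [6, 2] (per-tensor behavior)
    ```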

    - The block size for a particular axis must be a positive integer and may
      not exceed the dimension size of the tensor along that axis.

```
// Incorrect. The block size for axis 0 is -1.
%result = quant.qcast %input : tensor<6x2xf32> to
tensor<6x2x!quant.uniform<i8:f32:{0:-1}, {{1.0, 2.0}}>>

// Incorrect. The block size for axis 0 is 8 which is greater than the
// dimension size of tensor at axis 0 (which is 6).
%result = quant.qcast %input : tensor<6x2xf32> to
tensor<6x2x!quant.uniform<i8:f32:{0:8}, {{1.0, 2.0}}>>

// Correct. The block size for axis 0 is now 3.
%result = quant.qcast %input : tensor<6x2xf32> to
tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
```

    - shape(tensor)[i] % blockSizes[i] == 0 for each axis i in
      [0, 1, ..., rank(tensor)-1], where blockSizes[i] is the (possibly
      defaulted) block size for axis i.

```
// Incorrect. The block size for axis 0 is 4 and the corresponding
// dimension size is 6 and 6 % 4 != 0.
%result = quant.qcast %input : tensor<6x2xf32> to
tensor<6x2x!quant.uniform<i8:f32:{0:4}, {{1.0, 2.0}}>>

// Correct. The block size for axis 0 is now 3 making 6 % 3 = 0.
%result = quant.qcast %input : tensor<6x2xf32> to
tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
```

- shape(scales) = shape(zeroPoints) = shape(tensor) / blockSizes.

```
// Incorrect. shape(tensor) = [6,2], blockSizes = [3,2], but
// shape(scales) is [1,2] which is not equal to [6,2]/[3,2].
%result = quant.qcast %input : tensor<6x2xf32> to
tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0, 2.0}}>>

// Correct. shape(tensor) = [6,2], blockSizes = [3,2], and
// shape(scales) equals [6,2]/[3,2].
%result = quant.qcast %input : tensor<6x2xf32> to
tensor<6x2x!quant.uniform<i8:f32:{0:3}, {{1.0}, {2.0}}>>
```
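
    Taken together, the shape-related rules above amount to a small validity
    check, sketched here in Python (an illustrative approximation, not the
    actual verifier):

    ```python
    def is_valid_sub_channel(tensor_shape, block_sizes, scales_shape):
        """Check the sub-channel integrity rules for a ranked tensor with
        fully-expanded block sizes (one entry per axis)."""
        if not (len(block_sizes) == len(scales_shape) == len(tensor_shape)):
            return False
        for dim, block, n_blocks in zip(tensor_shape, block_sizes, scales_shape):
            if block <= 0 or block > dim:  # positive, within the axis
                return False
            if dim % block != 0:           # blocks must tile the axis exactly
                return False
            if n_blocks != dim // block:   # shape(scales) = shape / blockSizes
                return False
        return True

    print(is_valid_sub_channel([6, 2], [3, 2], [2, 1]))  # → True
    print(is_valid_sub_channel([6, 2], [4, 2], [2, 1]))  # → False: 6 % 4 != 0
    ```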
}];
let cppNamespace = "::mlir::quant";
30 changes: 21 additions & 9 deletions mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
@@ -13,6 +13,7 @@
#ifndef QUANT_BYTECODE
#define QUANT_BYTECODE

include "mlir/IR/BuiltinDialectBytecode.td"
include "mlir/IR/BytecodeBase.td"

def DoubleAPFloat:
@@ -81,20 +82,31 @@ def UniformQuantizedPerAxisType: DialectType<(type
}];
}

def UniformQuantizedSubChannelType
: DialectType<(type VarInt:$flags, Type:$storageType, Type:$expressedType,
SignedVarInt:$storageTypeMin, SignedVarInt:$storageTypeMax,
Array<SignedVarIntList>:$quantizedDimensions,
Array<SignedVarIntList>:$blockSizes, DenseElementsAttr:$scales,
DenseElementsAttr:$zeroPoints)> {
// Note: builder order differs from bytecode.
let cBuilder = [{
get<$_resultType>(context, flags, storageType, expressedType, scales,
zeroPoints, llvm::to_vector(llvm::map_range(quantizedDimensions,
[](int64_t dim) { return static_cast<int32_t>(dim);})), blockSizes,
storageTypeMin, storageTypeMax)
}];
}

/// This enum contains marker codes used to indicate which attribute is
/// currently being decoded, and how it should be decoded. The order of these
/// codes should generally be unchanged, as any changes will inevitably break
/// compatibility with older bytecode.

def QuantDialectTypes : DialectTypes<"Quant"> {
  let elems = [ReservedOrDead, AnyQuantizedType,
               AnyQuantizedTypeWithExpressedType, CalibratedQuantizedType,
               UniformQuantizedType, UniformQuantizedPerAxisType,
               UniformQuantizedSubChannelType];
}

#endif // QUANT_BYTECODE