perf: conv1d quantization (#1601)
* perf: implement quantized conv1d

* perf: vectorize int8 quantization

* feat: implement quantizable Conv1DSpec

* fix: implement Vec::round fallback for armv7

* fix: handle quantization of conv weights while loading
ebraraktas authored Mar 25, 2024
1 parent 5045b04 commit 8994330
Showing 16 changed files with 228 additions and 75 deletions.
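In practice this means Whisper models converted with int8 quantization now store their Conv1D kernels as int8 with a per-output-channel `weight_scale`, instead of keeping them in float. A minimal conversion sketch, assuming the existing `TransformersConverter` Python API and the `openai/whisper-tiny` checkpoint (both assumptions, not part of this diff):

```python
from ctranslate2.converters import TransformersConverter

# Hypothetical usage: convert a Whisper checkpoint with int8 weights.
# With this commit, the encoder's conv1d weights are quantized as well,
# with a per-output-channel scale stored under "weight_scale".
converter = TransformersConverter("openai/whisper-tiny")
converter.convert("whisper-tiny-ct2", quantization="int8")
```

(Per the model.cc change below, CUDA and DNNL builds convert the conv weights back to float at load time, since quantized convolution is not supported on those backends.)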
1 change: 1 addition & 0 deletions include/ctranslate2/layers/common.h
@@ -180,6 +180,7 @@ namespace ctranslate2 {
const ops::Conv1D _conv_op;
const StorageView& _weight;
const StorageView* _bias;
const StorageView* _qscale;
};

}
17 changes: 11 additions & 6 deletions include/ctranslate2/ops/conv1d.h
@@ -13,11 +13,13 @@ namespace ctranslate2 {
void operator()(const StorageView& input,
const StorageView& weight,
const StorageView& bias,
StorageView& output) const;
StorageView& output,
const StorageView* qscale = nullptr) const;

void operator()(const StorageView& input,
const StorageView& weight,
StorageView& output) const;
StorageView& output,
const StorageView* qscale = nullptr) const;

private:
dim_t _stride;
@@ -27,17 +29,20 @@
void operator()(const StorageView& input,
const StorageView& weight,
const StorageView* bias,
StorageView& output) const;
StorageView& output,
const StorageView* qscale) const;

template <Device D, typename T>
void compute(const StorageView& input,
const StorageView& weight,
const StorageView* bias,
StorageView& output) const;
StorageView& output,
const StorageView* qscale = nullptr) const;

void compute_with_gemm(const StorageView& input, const StorageView& weight, StorageView& output) const;
void compute_with_gemm(const StorageView& input, const StorageView& weight, StorageView& output,
const StorageView* qscale) const;

void im2col(const StorageView& input, StorageView& output, dim_t kernel_size) const;
void im2col_transposed(const StorageView& input, StorageView& output, dim_t kernel_size) const;
};

}
1 change: 1 addition & 0 deletions python/ctranslate2/specs/common_spec.py
@@ -45,6 +45,7 @@ def has_bias(self):
class Conv1DSpec(model_spec.LayerSpec):
def __init__(self):
self.weight = None
self.weight_scale = model_spec.OPTIONAL
self.bias = model_spec.OPTIONAL


8 changes: 8 additions & 0 deletions python/ctranslate2/specs/model_spec.py
@@ -220,12 +220,20 @@ def _quantize(spec, name, value):
"int8_bfloat16",
):
value = value.to("float32").numpy()
# For conv1d layer we need to reshape to 2D before calculating scale
old_shape = None
if len(value.shape) == 3:
old_shape = value.shape
value = value.reshape(value.shape[0], -1)
amax = np.amax(np.absolute(value), axis=1)
amax[amax == 0] = 127.0
scale = 127.0 / amax
value *= np.expand_dims(scale, 1)
value = np.rint(value)
value = value.astype(np.int8)
# reshape back to old shape
if old_shape:
value = value.reshape(old_shape)
scale = NumpyVariable(scale)
value = NumpyVariable(value)
elif quantization in ("float16", "bfloat16", "float32"):
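The reshape above lets the existing per-row scale computation apply per output channel when the value is a 3D conv weight. A small NumPy sketch of the same math, with a made-up weight shape:

```python
import numpy as np

# Made-up shape for illustration: (out_channels, in_channels, kernel_size)
weight = np.random.randn(4, 3, 5).astype(np.float32)

flat = weight.reshape(weight.shape[0], -1)   # one row per output channel
amax = np.amax(np.abs(flat), axis=1)
amax[amax == 0] = 127.0                      # avoid dividing by zero
scale = 127.0 / amax                         # per-output-channel scale
int8_weight = np.rint(flat * scale[:, None]).astype(np.int8).reshape(weight.shape)
```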
31 changes: 26 additions & 5 deletions src/cpu/kernels.cc
@@ -572,14 +572,35 @@ namespace ctranslate2 {

const auto amax = reduce_amax<TARGET_ISA>(x, depth);
const auto scale = (amax != 0.f ? int8_max / amax : 1.f);
using VecType = Vec<float, TARGET_ISA>;
const dim_t remaining = depth % VecType::width;
depth -= remaining;
auto vec_a_scale = VecType::load(scale);

if (shift_to_uint8) {
auto vec_int8_min = VecType::load(int8_min);
auto* dst = reinterpret_cast<uint8_t*>(y);
for (dim_t j = 0; j < depth; ++j)
dst[j] = round_func(x[j] * scale - int8_min);
for (dim_t j = 0; j < depth; j += VecType::width) {
auto v = VecType::load(x + j);
v = round_func(VecType::sub(VecType::mul(v, vec_a_scale), vec_int8_min));
VecType::convert_and_store(v, dst + j, VecType::width);
}
if (remaining) {
auto v = VecType::load(x + depth, remaining);
v = round_func(VecType::sub(VecType::mul(v, vec_a_scale), vec_int8_min));
VecType::convert_and_store(v, dst + depth, remaining);
}
} else {
for (dim_t j = 0; j < depth; ++j)
y[j] = round_func(x[j] * scale);
for (dim_t j = 0; j < depth; j += VecType::width) {
auto v = VecType::load(x + j);
v = round_func(VecType::mul(v, vec_a_scale));
VecType::convert_and_store(v, y + j, VecType::width);
}
if (remaining) {
auto v = VecType::load(x + depth, remaining);
v = round_func(VecType::mul(v, vec_a_scale));
VecType::convert_and_store(v, y + depth, remaining);
}
}

return scale;
@@ -612,7 +633,7 @@ namespace ctranslate2 {
bool shift_to_uint8,
bool round_before_cast) {
if (round_before_cast)
quantize_s8_batch(x, y, scales, batch_size, depth, shift_to_uint8, std::nearbyintf);
quantize_s8_batch(x, y, scales, batch_size, depth, shift_to_uint8, Vec<float, TARGET_ISA>::round);
else
quantize_s8_batch(x, y, scales, batch_size, depth, shift_to_uint8, identity());
}
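The vectorized loops process `Vec<float, TARGET_ISA>::width` elements per iteration and finish the ragged tail with a partial load/store, but the per-element result matches the old scalar loop. A scalar NumPy reference of what the row quantization computes (hypothetical input, rounding path):

```python
import numpy as np

x = np.random.randn(37).astype(np.float32)   # one row of activations, arbitrary depth
int8_max, int8_min = 127.0, -128.0

amax = np.abs(x).max()
scale = int8_max / amax if amax != 0.0 else 1.0
y_signed = np.rint(x * scale).astype(np.int8)                # default path
y_shifted = np.rint(x * scale - int8_min).astype(np.uint8)   # shift_to_uint8 path
```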
8 changes: 8 additions & 0 deletions src/cpu/vec.h
@@ -140,6 +140,14 @@ namespace ctranslate2 {
return a;
}

static inline float round(float a) {
return std::nearbyintf(a);
}

template<typename U>
static inline void convert_and_store(float v, U* a, dim_t count) {
*a = v;
}
};

template <typename T, CpuIsa ISA = CpuIsa::GENERIC>
11 changes: 11 additions & 0 deletions src/cpu/vec_avx.h
@@ -189,6 +189,17 @@ namespace ctranslate2 {
return reduce_m256(a, max);
}

static inline value_type round(value_type a) {
return _mm256_round_ps(a, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
}

template<typename T>
static void convert_and_store(value_type v, T* a, dim_t count) {
auto i32 = _mm256_cvttps_epi32(v);
int32_t tmp[8];
_mm256_storeu_si256(reinterpret_cast<__m256i *>(tmp), i32);
std::copy(tmp, tmp + count, a);
}
};

}
13 changes: 13 additions & 0 deletions src/cpu/vec_avx512.h
@@ -143,6 +143,19 @@ namespace ctranslate2 {
return _mm512_reduce_max_ps(a);
}

static inline value_type round(value_type a) {
return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
}

static inline void convert_and_store(value_type v, int8_t* a, const dim_t count) {
auto i32 = _mm512_cvttps_epi32(v);
_mm512_mask_cvtsepi32_storeu_epi8(a, get_length_mask(count), i32);
}

static inline void convert_and_store(value_type v, uint8_t* a, const dim_t count) {
auto u32 = _mm512_cvttps_epu32(v);
_mm512_mask_cvtusepi32_storeu_epi8(a, get_length_mask(count), u32);
}
};

}
35 changes: 33 additions & 2 deletions src/cpu/vec_neon.h
@@ -159,7 +159,38 @@ namespace ctranslate2 {
return vmaxvq_f32(a);
}

};
static inline value_type round(value_type v) {
#ifdef __aarch64__
return vrndiq_f32(v);
#else
float temp[4] = {std::nearbyintf(v[0]), std::nearbyintf(v[1]), std::nearbyintf(v[2]), std::nearbyintf(v[3])};
return load(temp);
#endif
}

}
static inline void convert_and_store(value_type v, int8_t *a, dim_t count) {
//convert float32x4_t to int32x4_t
auto i32x4 = vcvtq_s32_f32(v);
//then convert to int16x4_t
auto i16x4 = vqmovn_s32(i32x4);
//finally convert to int8x4_t
auto i8x8 = vqmovn_s16(vcombine_s16(i16x4, vdup_n_s16(0)));
int8_t tmp[8];
vst1_s8(tmp, i8x8);
std::copy(tmp, tmp + count, a);
}

static inline void convert_and_store(value_type v, uint8_t *a, dim_t count) {
//convert float32x4_t to uint32x4_t
auto u32x4 = vcvtq_u32_f32(v);
//then convert to uint16x4_t
auto u16x4 = vqmovn_u32(u32x4);
//finally convert to uint8x8_t
auto u8x8 = vqmovn_u16(vcombine_u16(u16x4, vdup_n_u16(0)));
uint8_t tmp[8];
vst1_u8(tmp, u8x8);
std::copy(tmp, tmp + count, a);
}
};
}
}
7 changes: 4 additions & 3 deletions src/layers/common.cc
@@ -435,7 +435,8 @@ namespace ctranslate2 {
dim_t dilation)
: _conv_op(stride, padding, dilation)
, _weight(model.get_variable(scope + "/weight"))
, _bias(model.get_variable_if_exists(scope + "/bias")) {
, _bias(model.get_variable_if_exists(scope + "/bias"))
, _qscale(model.get_variable_if_exists(scope + "/weight_scale")) {
}

DataType Conv1D::output_type() const {
@@ -452,9 +453,9 @@

void Conv1D::operator()(const StorageView& input, StorageView& output) const {
if (_bias)
_conv_op(input, _weight, *_bias, output);
_conv_op(input, _weight, *_bias, output, _qscale);
else
_conv_op(input, _weight, output);
_conv_op(input, _weight, output, _qscale);
}

}
25 changes: 23 additions & 2 deletions src/models/model.cc
@@ -198,8 +198,29 @@ namespace ctranslate2 {

// Convert "weight" variables to the expected compute type.
// Other float variables (e.g. biases) may be converted to another float type.
if (is_quantizable(name))
ensure_dtype(name, variable, weight_dtype);
if (is_quantizable(name)) {
auto variable_weight_dtype = weight_dtype;
// For conv layer, we need to reshape to ensure dtype as its weights are 3D.
auto is_conv = name.find("conv") != std::string::npos;
auto kernel_size = -1;
if (is_conv) {
kernel_size = variable.dim(2);
variable.reshape({variable.dim(0), variable.dim(1) * variable.dim(2)});
// For CUDA and DNNL backend, quantized convolution is not supported. Hence, convert to float_dtype.
if (device == Device::CUDA
#ifdef CT2_WITH_DNNL
|| true
#endif
) {
variable_weight_dtype = float_dtype;
}
}
ensure_dtype(name, variable, variable_weight_dtype);
// Undo reshape for conv weights
if (is_conv) {
variable.reshape({variable.dim(0), variable.dim(1) / kernel_size, kernel_size});
}
}
else if (is_convertible(variable, name)
&& is_float_type(variable.dtype())
&& variable.dtype() != float_dtype)
3 changes: 1 addition & 2 deletions src/models/whisper.cc
@@ -34,8 +34,7 @@ namespace ctranslate2 {
}

bool WhisperModel::is_quantizable(const std::string& variable_name) const {
return (Model::is_quantizable(variable_name)
&& variable_name.find("conv") == std::string::npos);
return Model::is_quantizable(variable_name);
}

bool WhisperModel::is_linear_weight(const std::string& variable_name) const {
22 changes: 9 additions & 13 deletions src/ops/conv1d.cc
@@ -15,28 +15,24 @@ namespace ctranslate2 {
void Conv1D::operator()(const StorageView& input,
const StorageView& weight,
const StorageView& bias,
StorageView& output) const {
operator()(input, weight, &bias, output);
StorageView& output,
const StorageView* qscale) const {
operator()(input, weight, &bias, output, qscale);
}

void Conv1D::operator()(const StorageView& input,
const StorageView& weight,
StorageView& output) const {
operator()(input, weight, nullptr, output);
StorageView& output,
const StorageView* qscale) const {
operator()(input, weight, nullptr, output, qscale);
}

void Conv1D::operator()(const StorageView& input,
const StorageView& weight,
const StorageView* bias,
StorageView& output) const {
StorageView& output,
const StorageView* qscale) const {
PROFILE("Conv1D");

if (input.dtype() != weight.dtype())
throw std::invalid_argument("Conv1D: input dtype is "
+ dtype_name(input.dtype())
+ " but expected dtype "
+ dtype_name(weight.dtype()));

const dim_t batch_size = input.dim(0);
const dim_t input_length = input.dim(2);
const dim_t out_channels = weight.dim(0);
Expand All @@ -47,7 +43,7 @@ namespace ctranslate2 {
output.resize({batch_size, out_channels, output_length});

DEVICE_AND_FLOAT_DISPATCH("Conv1D", input.device(), input.dtype(),
(compute<D, T>(input, weight, bias, output)));
(compute<D, T>(input, weight, bias, output, qscale)));
}

}