diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h
index 6c69275b8..7ea5e9126 100644
--- a/include/ctranslate2/layers/common.h
+++ b/include/ctranslate2/layers/common.h
@@ -180,6 +180,7 @@ namespace ctranslate2 {
       const ops::Conv1D _conv_op;
       const StorageView& _weight;
       const StorageView* _bias;
+      const StorageView* _qscale;
     };
 
   }
diff --git a/include/ctranslate2/ops/conv1d.h b/include/ctranslate2/ops/conv1d.h
index 84b1ac85a..fc37021d0 100644
--- a/include/ctranslate2/ops/conv1d.h
+++ b/include/ctranslate2/ops/conv1d.h
@@ -13,11 +13,13 @@ namespace ctranslate2 {
       void operator()(const StorageView& input,
                       const StorageView& weight,
                       const StorageView& bias,
-                      StorageView& output) const;
+                      StorageView& output,
+                      const StorageView* qscale = nullptr) const;
 
       void operator()(const StorageView& input,
                       const StorageView& weight,
-                      StorageView& output) const;
+                      StorageView& output,
+                      const StorageView* qscale = nullptr) const;
 
     private:
       dim_t _stride;
@@ -27,17 +29,20 @@ namespace ctranslate2 {
       void operator()(const StorageView& input,
                       const StorageView& weight,
                       const StorageView* bias,
-                      StorageView& output) const;
+                      StorageView& output,
+                      const StorageView* qscale) const;
 
       template <Device D, typename T>
       void compute(const StorageView& input,
                    const StorageView& weight,
                    const StorageView* bias,
-                   StorageView& output) const;
+                   StorageView& output,
+                   const StorageView* qscale = nullptr) const;
 
-      void compute_with_gemm(const StorageView& input, const StorageView& weight, StorageView& output) const;
+      void compute_with_gemm(const StorageView& input, const StorageView& weight, StorageView& output,
+                             const StorageView* qscale) const;
 
-      void im2col(const StorageView& input, StorageView& output, dim_t kernel_size) const;
+      void im2col_transposed(const StorageView& input, StorageView& output, dim_t kernel_size) const;
     };
 
   }
diff --git a/python/ctranslate2/specs/common_spec.py b/python/ctranslate2/specs/common_spec.py
index 9160867e0..b517ef77c 100644
--- a/python/ctranslate2/specs/common_spec.py
+++ b/python/ctranslate2/specs/common_spec.py
@@ -45,6 +45,7 @@ def has_bias(self):
 class Conv1DSpec(model_spec.LayerSpec):
     def __init__(self):
         self.weight = None
+        self.weight_scale = model_spec.OPTIONAL
         self.bias = model_spec.OPTIONAL
 
 
diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py
index 1db1a8369..c1e08f7c8 100644
--- a/python/ctranslate2/specs/model_spec.py
+++ b/python/ctranslate2/specs/model_spec.py
@@ -220,12 +220,20 @@ def _quantize(spec, name, value):
                 "int8_bfloat16",
             ):
                 value = value.to("float32").numpy()
+                # Conv1D weights are 3D: reshape to 2D before computing the scale.
+                old_shape = None
+                if len(value.shape) == 3:
+                    old_shape = value.shape
+                    value = value.reshape(value.shape[0], -1)
                 amax = np.amax(np.absolute(value), axis=1)
                 amax[amax == 0] = 127.0
                 scale = 127.0 / amax
                 value *= np.expand_dims(scale, 1)
                 value = np.rint(value)
                 value = value.astype(np.int8)
+                # Reshape back to the original 3D shape.
+                if old_shape:
+                    value = value.reshape(old_shape)
                 scale = NumpyVariable(scale)
                 value = NumpyVariable(value)
             elif quantization in ("float16", "bfloat16", "float32"):
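A note on the converter change: `_quantize` makes a conv weight look 2D just long enough to reuse the existing per-row scale computation, then restores the 3D shape (the `src/models/model.cc` hunk further down applies the same flatten/restore trick at load time). A minimal numpy sketch of that logic, with a made-up weight shape and variable names that are not part of the converter:

```python
import numpy as np

# Hypothetical Conv1D weight: (out_channels, in_channels, kernel_size).
weight = np.random.randn(4, 3, 5).astype(np.float32)

# Flatten so that each row corresponds to one output channel.
flat = weight.reshape(weight.shape[0], -1)

# Per-output-channel scale, as in _quantize(): 127 / max(|row|).
amax = np.amax(np.absolute(flat), axis=1)
amax[amax == 0] = 127.0
scale = 127.0 / amax

# Quantize, then restore the original 3D shape.
quant = np.rint(flat * np.expand_dims(scale, 1)).astype(np.int8).reshape(weight.shape)

# Dequantizing recovers the weight up to rounding error.
error = np.abs(quant.reshape(flat.shape) / scale[:, None] - flat)
assert error.max() <= 0.5 / scale.min() + 1e-6
```

The resulting `weight_scale` has one entry per output channel, which is the variable the new `_qscale` member of the C++ `Conv1D` layer picks up.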
diff --git a/src/cpu/kernels.cc b/src/cpu/kernels.cc
index a61dfb351..2371704ec 100644
--- a/src/cpu/kernels.cc
+++ b/src/cpu/kernels.cc
@@ -572,14 +572,35 @@ namespace ctranslate2 {
         const auto amax = reduce_amax(x, depth);
         const auto scale = (amax != 0.f ? int8_max / amax : 1.f);
 
+        using VecType = Vec<float, ISA>;
+        const dim_t remaining = depth % VecType::width;
+        depth -= remaining;
+        auto vec_a_scale = VecType::load(scale);
         if (shift_to_uint8) {
+          auto vec_int8_min = VecType::load(int8_min);
           auto* dst = reinterpret_cast<uint8_t*>(y);
-          for (dim_t j = 0; j < depth; ++j)
-            dst[j] = round_func(x[j] * scale - int8_min);
+          for (dim_t j = 0; j < depth; j += VecType::width) {
+            auto v = VecType::load(x + j);
+            v = round_func(VecType::sub(VecType::mul(v, vec_a_scale), vec_int8_min));
+            VecType::convert_and_store(v, dst + j, VecType::width);
+          }
+          if (remaining) {
+            auto v = VecType::load(x + depth, remaining);
+            v = round_func(VecType::sub(VecType::mul(v, vec_a_scale), vec_int8_min));
+            VecType::convert_and_store(v, dst + depth, remaining);
+          }
         } else {
-          for (dim_t j = 0; j < depth; ++j)
-            y[j] = round_func(x[j] * scale);
+          for (dim_t j = 0; j < depth; j += VecType::width) {
+            auto v = VecType::load(x + j);
+            v = round_func(VecType::mul(v, vec_a_scale));
+            VecType::convert_and_store(v, y + j, VecType::width);
+          }
+          if (remaining) {
+            auto v = VecType::load(x + depth, remaining);
+            v = round_func(VecType::mul(v, vec_a_scale));
+            VecType::convert_and_store(v, y + depth, remaining);
+          }
         }
 
         return scale;
       }
@@ -612,7 +633,7 @@ namespace ctranslate2 {
                      bool shift_to_uint8,
                      bool round_before_cast) {
       if (round_before_cast)
-        quantize_s8_batch(x, y, scales, batch_size, depth, shift_to_uint8, std::nearbyintf);
+        quantize_s8_batch(x, y, scales, batch_size, depth, shift_to_uint8, Vec<float, ISA>::round);
       else
         quantize_s8_batch(x, y, scales, batch_size, depth, shift_to_uint8, identity<float>());
     }
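The vectorized path above keeps the same per-row math as before: scale by `127 / max|x|`, round (now through `Vec::round` instead of `std::nearbyintf`), and optionally shift into the uint8 range by subtracting `int8_min`; the loop is simply strided by the vector width, with a partial store for the remainder. A scalar numpy equivalent of one row, for reference only (not the CTranslate2 API):

```python
import numpy as np

def quantize_row(x: np.ndarray, shift_to_uint8: bool = False):
    """Scalar reference for the per-row kernel above."""
    amax = np.max(np.abs(x))
    scale = 127.0 / amax if amax != 0.0 else 1.0
    if shift_to_uint8:
        # Mirrors round_func(x * scale - int8_min), with int8_min == -128.
        return np.rint(x * scale - (-128.0)).astype(np.uint8), scale
    return np.rint(x * scale).astype(np.int8), scale

x = np.array([0.1, -2.5, 3.7, 0.0], dtype=np.float32)
print(quantize_row(x))                       # int8 output and the row scale
print(quantize_row(x, shift_to_uint8=True))  # same values shifted into [0, 255]
```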
diff --git a/src/cpu/vec.h b/src/cpu/vec.h
index eaece98ab..13128658c 100644
--- a/src/cpu/vec.h
+++ b/src/cpu/vec.h
@@ -140,6 +140,14 @@ namespace ctranslate2 {
       return a;
     }
 
+    static inline float round(float a) {
+      return std::nearbyintf(a);
+    }
+
+    template <typename U>
+    static inline void convert_and_store(float v, U* a, dim_t count) {
+      *a = v;
+    }
   };
 
   template <typename T, CpuIsa ISA>
diff --git a/src/cpu/vec_avx.h b/src/cpu/vec_avx.h
index e5505e1d5..762961c3e 100644
--- a/src/cpu/vec_avx.h
+++ b/src/cpu/vec_avx.h
@@ -189,6 +189,17 @@ namespace ctranslate2 {
       return reduce_m256(a, max);
     }
 
+    static inline value_type round(value_type a) {
+      return _mm256_round_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
+    }
+
+    template <typename T>
+    static void convert_and_store(value_type v, T* a, dim_t count) {
+      auto i32 = _mm256_cvttps_epi32(v);
+      int32_t tmp[8];
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp), i32);
+      std::copy(tmp, tmp + count, a);
+    }
   };
 
 }
diff --git a/src/cpu/vec_avx512.h b/src/cpu/vec_avx512.h
index c9ab94bb3..46f547815 100644
--- a/src/cpu/vec_avx512.h
+++ b/src/cpu/vec_avx512.h
@@ -143,6 +143,19 @@ namespace ctranslate2 {
       return _mm512_reduce_max_ps(a);
     }
 
+    static inline value_type round(value_type a) {
+      return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
+    }
+
+    static inline void convert_and_store(value_type v, int8_t* a, const dim_t count) {
+      auto i32 = _mm512_cvttps_epi32(v);
+      _mm512_mask_cvtsepi32_storeu_epi8(a, get_length_mask(count), i32);
+    }
+
+    static inline void convert_and_store(value_type v, uint8_t* a, const dim_t count) {
+      auto u32 = _mm512_cvttps_epu32(v);
+      _mm512_mask_cvtusepi32_storeu_epi8(a, get_length_mask(count), u32);
+    }
   };
 
 }
diff --git a/src/cpu/vec_neon.h b/src/cpu/vec_neon.h
index 4ffb20773..4a407a5f1 100644
--- a/src/cpu/vec_neon.h
+++ b/src/cpu/vec_neon.h
@@ -159,7 +159,38 @@ namespace ctranslate2 {
       return vmaxvq_f32(a);
     }
 
-    };
+      static inline value_type round(value_type v) {
+#ifdef __aarch64__
+        return vrndiq_f32(v);
+#else
+        float temp[4] = {std::nearbyintf(v[0]), std::nearbyintf(v[1]), std::nearbyintf(v[2]), std::nearbyintf(v[3])};
+        return load(temp);
+#endif
+      }
 
-  }
+      static inline void convert_and_store(value_type v, int8_t* a, dim_t count) {
+        // Convert float32x4_t to int32x4_t...
+        auto i32x4 = vcvtq_s32_f32(v);
+        // ...then to int16x4_t...
+        auto i16x4 = vqmovn_s32(i32x4);
+        // ...and finally to int8x8_t.
+        auto i8x8 = vqmovn_s16(vcombine_s16(i16x4, vdup_n_s16(0)));
+        int8_t tmp[8];
+        vst1_s8(tmp, i8x8);
+        std::copy(tmp, tmp + count, a);
+      }
+
+      static inline void convert_and_store(value_type v, uint8_t* a, dim_t count) {
+        // Convert float32x4_t to uint32x4_t...
+        auto u32x4 = vcvtq_u32_f32(v);
+        // ...then to uint16x4_t...
+        auto u16x4 = vqmovn_u32(u32x4);
+        // ...and finally to uint8x8_t.
+        auto u8x8 = vqmovn_u16(vcombine_u16(u16x4, vdup_n_u16(0)));
+        uint8_t tmp[8];
+        vst1_u8(tmp, u8x8);
+        std::copy(tmp, tmp + count, a);
+      }
+    };
+  }
 }
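All of the `convert_and_store` overloads do the same job: take a register of already-rounded floats, narrow it to 8-bit integers, and write only `count` lanes so the remainder loop in `kernels.cc` can reuse them. AVX-512 and NEON narrow with saturation, while the generic and AVX variants rely on the values already being inside the 8-bit range after the `127 / amax` scaling. A numpy sketch of the intended semantics with the clamping made explicit (names are illustrative, not the C++ API):

```python
import numpy as np

def convert_and_store(v: np.ndarray, out: np.ndarray, count: int) -> None:
    """Write the first `count` lanes of `v` as saturated 8-bit integers."""
    if out.dtype == np.uint8:
        out[:count] = np.clip(v[:count], 0, 255).astype(np.uint8)
    else:
        out[:count] = np.clip(v[:count], -128, 127).astype(np.int8)

v = np.array([3.0, -86.0, 127.0, 0.0], dtype=np.float32)  # already rounded
dst = np.zeros(3, dtype=np.int8)
convert_and_store(v, dst, count=3)   # partial store, like the remainder path
print(dst)                           # [  3 -86 127]
```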
diff --git a/src/layers/common.cc b/src/layers/common.cc
index 4430b369c..6f56c01ef 100644
--- a/src/layers/common.cc
+++ b/src/layers/common.cc
@@ -435,7 +435,8 @@ namespace ctranslate2 {
                    dim_t dilation)
       : _conv_op(stride, padding, dilation)
       , _weight(model.get_variable(scope + "/weight"))
-      , _bias(model.get_variable_if_exists(scope + "/bias")) {
+      , _bias(model.get_variable_if_exists(scope + "/bias"))
+      , _qscale(model.get_variable_if_exists(scope + "/weight_scale")) {
     }
 
     DataType Conv1D::output_type() const {
@@ -452,9 +453,9 @@ namespace ctranslate2 {
 
     void Conv1D::operator()(const StorageView& input, StorageView& output) const {
       if (_bias)
-        _conv_op(input, _weight, *_bias, output);
+        _conv_op(input, _weight, *_bias, output, _qscale);
       else
-        _conv_op(input, _weight, output);
+        _conv_op(input, _weight, output, _qscale);
     }
 
   }
diff --git a/src/models/model.cc b/src/models/model.cc
index 97bf3d1b5..8c08b07ab 100644
--- a/src/models/model.cc
+++ b/src/models/model.cc
@@ -198,8 +198,29 @@ namespace ctranslate2 {
 
       // Convert "weight" variables to the expected compute type.
       // Other float variables (e.g. biases) may be converted to another float type.
-      if (is_quantizable(name))
-        ensure_dtype(name, variable, weight_dtype);
+      if (is_quantizable(name)) {
+        auto variable_weight_dtype = weight_dtype;
+        // Conv weights are 3D: reshape them to 2D so ensure_dtype can process them like linear weights.
+        auto is_conv = name.find("conv") != std::string::npos;
+        auto kernel_size = -1;
+        if (is_conv) {
+          kernel_size = variable.dim(2);
+          variable.reshape({variable.dim(0), variable.dim(1) * variable.dim(2)});
+          // Quantized convolution is not supported on the CUDA and DNNL backends, so fall back to float_dtype.
+          if (device == Device::CUDA
+#ifdef CT2_WITH_DNNL
+              || true
+#endif
+          ) {
+            variable_weight_dtype = float_dtype;
+          }
+        }
+        ensure_dtype(name, variable, variable_weight_dtype);
+        // Undo the reshape for conv weights.
+        if (is_conv) {
+          variable.reshape({variable.dim(0), variable.dim(1) / kernel_size, kernel_size});
+        }
+      }
       else if (is_convertible(variable, name)
                && is_float_type(variable.dtype())
                && variable.dtype() != float_dtype)
diff --git a/src/models/whisper.cc b/src/models/whisper.cc
index da12898e9..349279240 100644
--- a/src/models/whisper.cc
+++ b/src/models/whisper.cc
@@ -34,8 +34,7 @@ namespace ctranslate2 {
     }
 
     bool WhisperModel::is_quantizable(const std::string& variable_name) const {
-      return (Model::is_quantizable(variable_name)
-              && variable_name.find("conv") == std::string::npos);
+      return Model::is_quantizable(variable_name);
     }
 
     bool WhisperModel::is_linear_weight(const std::string& variable_name) const {
diff --git a/src/ops/conv1d.cc b/src/ops/conv1d.cc
index 6cbe1b539..bde97dc8f 100644
--- a/src/ops/conv1d.cc
+++ b/src/ops/conv1d.cc
@@ -15,28 +15,24 @@ namespace ctranslate2 {
     void Conv1D::operator()(const StorageView& input,
                             const StorageView& weight,
                             const StorageView& bias,
-                            StorageView& output) const {
-      operator()(input, weight, &bias, output);
+                            StorageView& output,
+                            const StorageView* qscale) const {
+      operator()(input, weight, &bias, output, qscale);
     }
 
     void Conv1D::operator()(const StorageView& input,
                             const StorageView& weight,
-                            StorageView& output) const {
-      operator()(input, weight, nullptr, output);
+                            StorageView& output,
+                            const StorageView* qscale) const {
+      operator()(input, weight, nullptr, output, qscale);
     }
 
     void Conv1D::operator()(const StorageView& input,
                             const StorageView& weight,
                             const StorageView* bias,
-                            StorageView& output) const {
+                            StorageView& output,
+                            const StorageView* qscale) const {
       PROFILE("Conv1D");
-
-      if (input.dtype() != weight.dtype())
-        throw std::invalid_argument("Conv1D: input dtype is "
-                                    + dtype_name(input.dtype())
-                                    + " but expected dtype "
-                                    + dtype_name(weight.dtype()));
-
       const dim_t batch_size = input.dim(0);
       const dim_t input_length = input.dim(2);
       const dim_t out_channels = weight.dim(0);
@@ -47,7 +43,7 @@ namespace ctranslate2 {
       output.resize({batch_size, out_channels, output_length});
 
       DEVICE_AND_FLOAT_DISPATCH("Conv1D", input.device(), input.dtype(),
-                                (compute<D, T>(input, weight, bias, output)));
+                                (compute<D, T>(input, weight, bias, output, qscale)));
     }
 
   }
diff --git a/src/ops/conv1d_cpu.cc b/src/ops/conv1d_cpu.cc
index 35a8488f7..a45388b57 100644
--- a/src/ops/conv1d_cpu.cc
+++ b/src/ops/conv1d_cpu.cc
@@ -10,7 +10,10 @@ namespace ctranslate2 {
     void Conv1D::compute<Device::CPU, float>(const StorageView& input,
                                              const StorageView& weight,
                                              const StorageView* bias,
-                                             StorageView& output) const {
+                                             StorageView& output,
+                                             const StorageView* qscale) const {
+      if (qscale)
+        throw std::runtime_error("Quantization is not supported in this Conv1D implementation");
       dnnl::engine engine(dnnl::engine::kind::cpu, 0);
       dnnl::stream engine_stream(engine);
 
@@ -113,31 +116,22 @@ namespace ctranslate2 {
 #else
 
-#  ifdef CT2_WITH_MKL
-#    include <mkl.h>
-#  elif CT2_WITH_ACCELERATE
-#    include <Accelerate/Accelerate.h>
-#  elif CT2_WITH_OPENBLAS
-#    include <cblas.h>
-#  else
-#    define CT2_NO_BLAS
-#  endif
-
 #  include "ctranslate2/ops/gemm.h"
 #  include "cpu/parallel.h"
+#  include "ctranslate2/ops/quantize.h"
+#  include "ctranslate2/ops/dequantize.h"
 
 namespace ctranslate2 {
   namespace ops {
 
     template<>
-    void Conv1D::compute<Device::CPU, float>(const StorageView& input,
-                                             const StorageView& weight,
-                                             const StorageView* bias,
-                                             StorageView& output) const {
+    void
+    Conv1D::compute<Device::CPU, float>(const StorageView &input, const StorageView &weight, const StorageView *bias,
+                                        StorageView &output, const StorageView *qscale) const {
       if (_dilation != 1)
         throw std::runtime_error("Dilation is not supported in this Conv1D implementation");
 
-      compute_with_gemm(input, weight, output);
+      compute_with_gemm(input, weight, output, qscale);
 
       // Add bias
       if (bias) {
         // Need to broadcast along dims 0 and 2, because output shape is:
@@ -159,67 +153,97 @@ namespace ctranslate2 {
       }
     }
 
-    void Conv1D::compute_with_gemm(const StorageView& input,
-                                   const StorageView& weight,
-                                   StorageView& output) const {
+    void Conv1D::compute_with_gemm(const StorageView &input, const StorageView &weight, StorageView &output,
+                                   const StorageView *qscale) const {
       const dim_t batch_size = input.dim(0);
       const dim_t in_channels = input.dim(1);
       const dim_t out_channels = weight.dim(0);
       const dim_t kernel_size = weight.dim(2);
       const dim_t output_length = output.dim(2);
-      std::vector<dim_t> im2col_output_shape{batch_size, in_channels * kernel_size, output_length};
-      StorageView im2col_output(std::move(im2col_output_shape), static_cast<float>(0.0), Device::CPU);
-      im2col(input, im2col_output, kernel_size);
 
+      // Create the im2col_output tensor with shape (batch_size, output_length, in_channels * kernel_size).
+      // The transposed layout is what makes quantization work here:
+      // * the GEMM is run as (weight x im2col_output) to avoid extra copies;
+      // * the input (RHS) must be quantized along columns so it can be dequantized afterwards;
+      // * but the Quantize op computes scales along rows;
+      // * hence im2col_output is generated transposed and the GEMM runs with transpose_b = true;
+      // * the qinput_scale computed over the rows of im2col_output then corresponds to the
+      //   columns of the multiplied matrix (because of the transpose).
+      // The same matrix is also used for the FLOAT32 path.
+      StorageView im2col_output({batch_size, output_length, in_channels * kernel_size}, 0.0f, weight.device());
+      im2col_transposed(input, im2col_output, kernel_size);
 
       // Create a 2D view of weight to use in GEMM
-      const StorageView weight_view({weight.dim(0), in_channels * kernel_size}, const_cast<float*>(weight.data<float>()));
+      StorageView weight_view(weight.dtype(), weight.device());
+      weight_view.view(const_cast<void*>(weight.buffer()), {weight.dim(0), in_channels * kernel_size});
 
       const dim_t m = out_channels;
       const dim_t n = output_length;
-      const dim_t k = im2col_output.dim(1);
+      const dim_t k = in_channels * kernel_size;
       const dim_t strideb = k * output_length;
       const dim_t stridec = out_channels * output_length;
       auto* b = im2col_output.data<float>();
       auto* c = output.data<float>();
-      const Gemm gemm(1.0, 0.0, false, false);
+      const Gemm gemm(1.0, 0.0, false, true);
+      const Quantize quantize_op(Quantize::ScaleType::PER_LAYER,
+                                 /*shift_to_uint8=*/false,
+                                 /*round_before_cast=*/true);
+      const Dequantize dequantize_op;
+      const auto device = im2col_output.device();
       cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
+        StorageView qinput(weight.dtype(), device);
+        StorageView qinput_scale(device);
+        if (qscale)
+          qinput_scale.to(qscale->dtype());
+        StorageView qoutput(DataType::INT32, device);
         for (dim_t i = begin; i < end; ++i) {
           float* b_i = b + (i * strideb);
           float* c_i = c + (i * stridec);
+          StorageView bb({n, k}, b_i); // transposed
           StorageView cc({m, n}, c_i);
-          StorageView bb({k, n}, b_i);
-          gemm(weight_view, bb, cc);
+
+          if (qscale) {
+            quantize_op(bb, qinput, qinput_scale);
+            gemm(weight_view, qinput, qoutput);
+            dequantize_op(qoutput,
+                          *qscale,
+                          qinput_scale,
+                          /*trans_a=*/false,
+                          /*trans_b=*/true,
+                          cc);
+          } else {
+            gemm(weight_view, bb, cc);
+          }
         }
       });
     }
 
-    void Conv1D::im2col(const StorageView& input, StorageView& output, const dim_t kernel_size) const {
+    void Conv1D::im2col_transposed(const StorageView& input, StorageView& output, const dim_t kernel_size) const {
       // input: batch_size x in_channels x input_length
-      // output: batch_size x (in_channels * kernel_size) x output_length
+      // output: batch_size x output_length x (in_channels * kernel_size)
       const dim_t batch_size = input.dim(0);
       const dim_t in_channels = input.dim(1);
       const dim_t input_length = input.dim(2);
       auto* out = output.data<float>();
       const auto* in = input.data<float>();
-      dim_t input_channel_offset = 0;
-      dim_t out_index = 0;
-      for (int i = 0; i < batch_size; i++) {
-        for (int c = 0; c < in_channels; c++) {
-          // For each input channel fill (kernel_size * output_length) items in output array
-          for (int k = 0; k < kernel_size; k++) {
-            for (dim_t ti = -_padding; ti <= (input_length - kernel_size + _padding); ti += _stride) {
+      dim_t out_offset = 0;
+      const auto in_batch_stride = in_channels * input_length;
+      for (dim_t batch_offset = 0; batch_offset < batch_size * in_batch_stride; batch_offset += in_batch_stride) {
+        for (int ti = -_padding; ti <= (input_length - kernel_size + _padding); ti += _stride) {
+          for (dim_t c = batch_offset; c < (batch_offset + in_channels * input_length); c += input_length) {
+            for (int k = 0; k < kernel_size; k++) {
               // Fill items in [0, input_length) range
-              const auto window_i = k + ti;
+              auto window_i = k + ti;
               if (0 <= window_i && window_i < input_length) {
-                out[out_index] = in[window_i + input_channel_offset];
+                out[out_offset] = in[window_i + c];
               }
-              out_index += 1;
+              out_offset += 1;
             }
           }
-          input_channel_offset += input_length;
         }
       }
     }
+
   }
 }
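Putting the CPU pieces together: `im2col_transposed` lays out the receptive field of each output position as one row, so the per-row scales produced by `Quantize` line up with the columns of the transposed GEMM operand, and `Dequantize` can then divide out both the per-output-channel weight scale and the per-position input scale. A compact numpy model of that pipeline, using plain numpy in place of the `Gemm`/`Quantize`/`Dequantize` ops, a single batch item, no padding/stride/dilation, and hypothetical sizes:

```python
import numpy as np

def im2col_transposed(x, kernel_size):
    # x: (in_channels, input_length) -> (output_length, in_channels * kernel_size)
    in_channels, input_length = x.shape
    output_length = input_length - kernel_size + 1
    cols = np.zeros((output_length, in_channels * kernel_size), dtype=x.dtype)
    for t in range(output_length):
        cols[t] = x[:, t:t + kernel_size].reshape(-1)
    return cols

def quantize_rows(m):
    # Per-row int8 quantization, like the Quantize op applied to the transposed matrix.
    amax = np.max(np.abs(m), axis=1)
    amax[amax == 0] = 127.0
    scale = 127.0 / amax
    return np.rint(m * scale[:, None]).astype(np.int8), scale

x = np.random.randn(2, 10).astype(np.float32)    # (in_channels, input_length)
w = np.random.randn(3, 2, 3).astype(np.float32)  # (out_channels, in_channels, kernel_size)

cols = im2col_transposed(x, kernel_size=3)               # (output_length, in_channels * kernel_size)
qw, w_scale = quantize_rows(w.reshape(w.shape[0], -1))   # one scale per output channel
qc, c_scale = quantize_rows(cols)                        # one scale per output position

# int32 GEMM with transpose_b=true, then dequantize with the outer product of the scales.
acc = qw.astype(np.int32) @ qc.astype(np.int32).T        # (out_channels, output_length)
y_quant = acc / (w_scale[:, None] * c_scale[None, :])

# Float reference: the 2D weight view times the transposed im2col matrix.
y_ref = w.reshape(w.shape[0], -1) @ cols.T
print(np.max(np.abs(y_quant - y_ref)))                   # small quantization error
```

The `transpose_b` GEMM is what lets the row-wise `qinput_scale` double as a column-wise scale during dequantization, which is exactly the point made by the comment in `compute_with_gemm`.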
diff --git a/src/ops/conv1d_gpu.cu b/src/ops/conv1d_gpu.cu
index 3b389358c..6f4d10b39 100644
--- a/src/ops/conv1d_gpu.cu
+++ b/src/ops/conv1d_gpu.cu
@@ -9,7 +9,11 @@ namespace ctranslate2 {
     void Conv1D::compute(const StorageView& input,
                          const StorageView& weight,
                          const StorageView* bias,
-                         StorageView& output) const {
+                         StorageView& output,
+                         const StorageView* qscale) const {
+      if (qscale)
+        throw std::runtime_error("Quantization is not supported in this Conv1D implementation");
+
 #ifndef CT2_WITH_CUDNN
       (void)input;
       (void)weight;
@@ -144,7 +148,8 @@ namespace ctranslate2 {
     Conv1D::compute<Device::CUDA, T>(const StorageView& input,         \
                                      const StorageView& weight,        \
                                      const StorageView* bias,          \
-                                     StorageView& output) const;
+                                     StorageView& output,              \
+                                     const StorageView* qscale) const;
 
 DECLARE_IMPL(float)
 DECLARE_IMPL(float16_t)
diff --git a/tests/ops_test.cc b/tests/ops_test.cc
index 1ceae1dfd..e0615c052 100644
--- a/tests/ops_test.cc
+++ b/tests/ops_test.cc
@@ -774,6 +774,14 @@ TEST_P(OpDeviceTest, QuantizeINT8) {
     expect_storage_eq(qa, expected_qa);
   }
 
+  // With rounding before cast and shift to uint8.
+  {
+    StorageView expected_qa(a.shape(), std::vector<int8_t>{1, 90, -64, -103, -98, -1, 110, -128});
+    ops::Quantize(ops::Quantize::ScaleType::GLOBAL, true, true)(a, qa, scale);
+    expect_storage_eq(scale, expected_scale);
+    expect_storage_eq(qa, expected_qa);
+  }
+
   // Without rounding before cast (legacy behavior).
   {
     StorageView expected_qa(a.shape(), std::vector<int8_t>{-127, -38, 63, 25, 30, 127, -18, 0});