diff --git a/src/gpu/intel/ocl/micro_sdpa.cl b/src/gpu/intel/ocl/micro_sdpa.cl
index b3ba2c4e8f6..b56be86bc5b 100644
--- a/src/gpu/intel/ocl/micro_sdpa.cl
+++ b/src/gpu/intel/ocl/micro_sdpa.cl
@@ -21,11 +21,14 @@
 #include "gemm_kq.h"
 #include "gemm_vs.h"
 
-/* The quantization parameter may be unique for each token/element */
-#define QUANTIZE_2D 2
-
 /* The quantization parameter shares the same value across the work-group */
-#define QUANTIZE_COMMON 3
+#define QUANTIZE_COMMON 1
+
+/* One quantization parameter for each token */
+#define QUANTIZE_1D 2
+
+/* One quantization parameter for each group */
+#define QUANTIZE_2D 3
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define DIV_UP(x, y) (((x) + (y)-1) / (y))
@@ -93,6 +96,34 @@ DECLARE_2D_TILE_BLOCK_OPS(mask_tile_type, MSK_DATA_T, SUBGROUP_SIZE, mask_br,
         mask_bc, mask_nbr, mask_nbc)
 #endif
 
+#if KEY_SCALES == QUANTIZE_1D
+DECLARE_2D_TILE(k_scales_tile_type, KEY_ATTR_SCALES_DATA_T, SUBGROUP_SIZE,
+        ugemm_kq_sg_tile_m, 1, 1, 1)
+DECLARE_2D_TILE(k_scales_tile_type_float, float, SUBGROUP_SIZE,
+        ugemm_kq_sg_tile_m, 1, 1, 1)
+DECLARE_2D_TILE_BLOCK_OPS(k_scales_tile_type, KEY_ATTR_SCALES_DATA_T,
+        SUBGROUP_SIZE, ugemm_kq_sg_tile_m, 1, 1, 1)
+
+DECLARE_2D_TILE_HREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
+        ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0,
+        ugemm_kq_c_type_nblock1, k_scales_tile_type_float, SUBGROUP_SIZE,
+        ugemm_kq_sg_tile_m, 1, 1, 1)
+#endif
+
+#if VAL_SCALES == QUANTIZE_1D
+DECLARE_2D_TILE(v_scales_tile_type, KEY_ATTR_SCALES_DATA_T, SUBGROUP_SIZE,
+        ugemm_kq_sg_tile_m, 1, 1, 1)
+DECLARE_2D_TILE(v_scales_tile_type_float, float, SUBGROUP_SIZE,
+        ugemm_kq_sg_tile_m, 1, 1, 1)
+DECLARE_2D_TILE_BLOCK_OPS(v_scales_tile_type, KEY_ATTR_SCALES_DATA_T,
+        SUBGROUP_SIZE, ugemm_kq_sg_tile_m, 1, 1, 1)
+
+DECLARE_2D_TILE_HREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
+        ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0,
+        ugemm_kq_c_type_nblock1, v_scales_tile_type_float, SUBGROUP_SIZE,
+        ugemm_kq_sg_tile_m, 1, 1, 1)
+#endif
+
 #ifdef BLOCK_A
 DECLARE_2D_TILE_BLOCK_OPS(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE,
         ugemm_vs_sg_tile_m, 1, 1, ugemm_vs_sg_tile_n)
@@ -299,7 +330,19 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
             /* sg_size */ SUBGROUP_SIZE,
             /* cache */ LSC_LDCC_L1C_L3C);
 
-#if KEY_SCALES == QUANTIZE_2D
+#if KEY_SCALES == QUANTIZE_1D
+    cooperative_prefetch_2d_maybe_rem(
+            /* ptr */ K_scales,
+            /* r */ k,
+            /* c */ 1,
+            /* rmax */ ugemm_kq_wg_tile_m,
+            /* cmax */ 1,
+            /* ld */ ldkq,
+            /* sg_id */ sg_ij,
+            /* n_sg */ sg_per_wg,
+            /* sg_size */ SUBGROUP_SIZE,
+            /* cache */ LSC_LDCC_L1C_L3C);
+#elif KEY_SCALES == QUANTIZE_2D
     cooperative_prefetch_2d_maybe_rem(
             /* ptr */ K_scales,
             /* r */ k,
@@ -414,6 +457,10 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
 #if KEY_SCALES == QUANTIZE_COMMON
 #define k_scale_op(x) ((x)*k_scale)
         tile_elementwise(S_tile, k_scale_op);
+#elif KEY_SCALES == QUANTIZE_1D
+        /* Load 1D K scales */
+        k_scales_tile_type k_scales_tile;
+        tile_load_block(&k_scales_tile, K_scales, 0, k0 + sg_i0_kq, 0);
 #endif
 
         /* Apply attention mask */
@@ -436,6 +483,13 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
         tile_hbroadcast_min(&S_tile, k_mask);
 #endif
 
+        /* Apply 1D K scales */
+#if KEY_SCALES == QUANTIZE_1D
+        k_scales_tile_type_float k_scales_tile_float;
+        tile_copy(k_scales_tile, k_scales_tile_float);
+        tile_hbroadcast_mul(&S_tile, k_scales_tile_float);
+#endif
+
 #if WITH_CAUSAL_MASK
 #define greater_than(offset_k, offset_q) (offset_k > offset_q)
         /* Apply causal mask */
@@ -468,7 +522,20 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
                 /* sg_size */ SUBGROUP_SIZE,
                 /* cache */ LSC_LDCC_L1C_L3C);
 
-#if VAL_SCALES == QUANTIZE_2D
+#if VAL_SCALES == QUANTIZE_1D
+        /* Prefetch 1D V scales. */
+        cooperative_prefetch_2d_maybe_rem(
+                /* ptr */ V_scales,
+                /* r */ 1,
+                /* c */ k - k0,
+                /* rmax */ 1,
+                /* cmax */ k_chunk,
+                /* ld */ ldvq,
+                /* sg_id */ sg_ij,
+                /* n_sg */ sg_per_wg,
+                /* sg_size */ SUBGROUP_SIZE,
+                /* cache */ LSC_LDCC_L1C_L3C);
+#elif VAL_SCALES == QUANTIZE_2D
         /* Prefetch V scales. */
         cooperative_prefetch_2d_maybe_rem(
                 /* ptr */ V_scales,
@@ -509,6 +576,12 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
 #define scaled_exp(x) native_vexp2(x *scale)
         tile_elementwise(S_tile, scaled_exp);
 
+#if VAL_SCALES == QUANTIZE_1D
+        /* Load 1D V scales */
+        v_scales_tile_type v_scales_tile;
+        tile_load_block(&v_scales_tile, V_scales, 0, k0 + sg_i0_kq, 0);
+#endif
+
 #ifdef ALT_MAX
         /* Read back WG-wide maxima and adjust S to match */
         intel_work_group_barrier_wait(CLK_LOCAL_MEM_FENCE);
@@ -526,6 +599,13 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
         tile_fill(S_sum_tile1, 0.0f);
         tile_vreduce_add(S_tile, &S_sum_tile1);
 
+#if VAL_SCALES == QUANTIZE_1D
+        /* Apply 1D V scales to S tile */
+        v_scales_tile_type_float v_scales_tile_float;
+        tile_copy(v_scales_tile, v_scales_tile_float);
+        tile_hbroadcast_mul(&S_tile, v_scales_tile_float);
+#endif
+
         /* Convert to half or bf16, VNNI format */
         s_tile_type_packed S_tile_packed;
         tile_copy_to_vec2(S_tile, S_tile_packed, VEC_TYPE2);
@@ -590,7 +670,19 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
                 /* n_sg */ sg_per_wg,
                 /* sg_size */ SUBGROUP_SIZE,
                 /* cache*/ LSC_LDCC_L1C_L3C);
-#if KEY_SCALES == QUANTIZE_2D
+#if KEY_SCALES == QUANTIZE_1D
+        cooperative_prefetch_2d_maybe_rem(
+                /* ptr */ K_scales + (k0 + ugemm_kq_wg_tile_m),
+                /* r */ k - k0 - ugemm_kq_wg_tile_m,
+                /* c */ 1,
+                /* rmax */ ugemm_kq_wg_tile_m,
+                /* cmax */ 1,
+                /* ld */ ldkq,
+                /* sg_id */ sg_ij,
+                /* n_sg */ sg_per_wg,
+                /* sg_size */ SUBGROUP_SIZE,
+                /* cache */ LSC_LDCC_L1C_L3C);
+#elif KEY_SCALES == QUANTIZE_2D
         cooperative_prefetch_2d_maybe_rem(
                 /* ptr */ K_scales + (k0 + ugemm_kq_wg_tile_m),
                 /* r */ k - k0 - ugemm_kq_wg_tile_m,
diff --git a/src/gpu/intel/ocl/micro_sdpa.cpp b/src/gpu/intel/ocl/micro_sdpa.cpp
index 9fda6ed5fd1..a8a21dcb012 100644
--- a/src/gpu/intel/ocl/micro_sdpa.cpp
+++ b/src/gpu/intel/ocl/micro_sdpa.cpp
@@ -192,31 +192,6 @@ sdpa_config_t *choose_config_xehpc(
     return nullptr;
 }
 
-/// Returns true if a common scale value is used for each slice of the tensor
-/// operation. For 4D case it's when the mask's two first bits are on and two
-/// last bits are off.
-/// Examples:
-///   | mask      | result |
-///   |-----------+--------|
-///   | 0  (0000) | true   |
-///   | 12 (0011) | false  |
-///   | 3  (1100) | true   |
-///   | 1  (1000) | true   |
-///   | 8  (0001) | false  |
-bool with_quantize_common(const runtime_scales_t &scales) {
-    return !scales.has_default_values()
-            && (((scales.mask_ & 3) != 0 && (scales.mask_ & 12) == 0)
-                    || scales.mask_ == 0);
-}
-
-/// Returns true if a common zero points value is used for each slice of the
-/// tensor operation
-bool with_quantize_common(const zero_points_t &zp) {
-    int mask = zp.get(DNNL_ARG_WEIGHTS);
-    return !zp.has_default_values()
-            && (((mask & 3) != 0 && (mask & 12) == 0) || mask == 0);
-}
-
 } /* anonymous namespace */
 
 status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
@@ -270,8 +245,11 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
                                                            : MatrixLayout::N;
     };
 
-    bool kq_common_scales = with_quantize_common(d->kq_scales);
-    bool kq_common_zp = with_quantize_common(d->kq_zero_points);
+    auto k_scales_q = key_scales_quant();
+    auto k_zp_q = key_zp_quant();
+
+    bool k_scales_2d = (k_scales_q == pd_t::quantization_t::grouped);
+    bool k_zp_2d = (k_zp_q == pd_t::quantization_t::grouped);
 
     /* Set up GEMMProblem structure for first GEMM: K^T * Q */
     GEMMProblem problem;
@@ -293,7 +271,7 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
     opts_kq.localB = true;
     opts_kq.slmPtr = true;
 
-    if (with_key_scales() && !kq_common_scales) {
+    if (k_scales_2d) {
         auto scale_dt = key_scales_dt();
         problem_kq.Ta_scale = jit::convert_dnnl_to_kernel_type(scale_dt);
         problem_kq.A_scale.setAlignment(
@@ -307,16 +285,14 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
         problem_kq.AO.setAlignment(
                 int8_t(d->keys() * types::data_type_size(zp_dt)));
         problem_kq.AO.layout = MatrixLayout::N;
-        problem_kq.aoPtrDims = kq_common_zp ? 0 : 2;
+        problem_kq.aoPtrDims = k_zp_2d ? 2 : 0;
        problem_kq.aOffset = ABOffset::Calc;
     }
-
-    if (with_key_scales() || with_key_zp()) {
+    if (k_scales_2d || with_key_zp()) {
         problem_kq.aqGroupM = 1;
-        problem_kq.aqGroupK
-                = (kq_common_scales || kq_common_zp) ? 1 : key_group_size();
+        problem_kq.aqGroupK = (k_scales_2d || k_zp_2d) ? key_group_size() : 1;
     }
-    opts_kq.scaleA = with_key_scales() && !kq_common_scales;
+    opts_kq.scaleA = k_scales_2d;
     opts_kq.offsetA = with_key_zp();
 
     problem_kq.B.layout = MatrixLayout::Pr;
@@ -354,8 +330,11 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
                 ex.what());
     }
 
-    bool vs_common_scales = with_quantize_common(d->vs_scales);
-    bool vs_common_zp = with_quantize_common(d->vs_zero_points);
+    auto v_scales_q = value_scales_quant();
+    auto v_zp_q = value_zp_quant();
+
+    bool v_scales_2d = (v_scales_q == pd_t::quantization_t::grouped);
+    bool v_zp_2d = (v_zp_q == pd_t::quantization_t::grouped);
 
     /* Set up microkernel options */
     micro::GEMMProtocol::Options opts_vs;
@@ -366,7 +345,7 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
     auto problem_vs = problem;
     problem_vs.Ta_ext = jit::convert_dnnl_to_kernel_type(val_md()->data_type);
     problem_vs.A.layout = convert_dnnl_to_kernel_layout(val_md());
-    if (with_value_scales() && !vs_common_scales) {
+    if (v_scales_2d) {
         auto scale_dt = value_scales_dt();
         problem_vs.Ta_scale = jit::convert_dnnl_to_kernel_type(scale_dt);
         problem_vs.A_scale.setAlignment(uint8_t(d->head_size()
@@ -381,16 +360,16 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
         problem_vs.AO.setAlignment(uint8_t(d->head_size() / value_group_size()
                 * types::data_type_size(zp_dt)));
         problem_vs.AO.layout = MatrixLayout::N;
-        problem_vs.aoPtrDims = vs_common_zp ? 0 : 2;
+        problem_vs.aoPtrDims = v_zp_2d ? 2 : 0;
         problem_vs.aOffset = ABOffset::Calc;
     }
-    if (with_value_scales() || with_value_zp()) {
-        problem_vs.aqGroupM = (vs_common_scales || vs_common_zp)
-                ? 1
-                : utils::rnd_up_pow2(value_group_size());
+    if (v_scales_2d || with_value_zp()) {
+        problem_vs.aqGroupM = (v_scales_2d || v_zp_2d)
+                ? utils::rnd_up_pow2(value_group_size())
+                : 1;
         problem_vs.aqGroupK = 1;
     }
-    opts_vs.scaleA = with_value_scales() && !vs_common_scales;
+    opts_vs.scaleA = v_scales_2d;
     opts_vs.offsetA = with_value_zp();
 
     problem_vs.B.layout = MatrixLayout::Pr;
@@ -489,21 +468,15 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {
     kernel_ctx.define_int("TRANSPOSE_K",
             gemm_desc_t::get_trans(*pd()->key_md()) == dnnl_trans);
 
-    int kq_scale_mask = (static_cast<int>(pd()->with_key_scales()) << 1)
-            | static_cast<int>(with_quantize_common(d->kq_scales));
-    kernel_ctx.define_int("KEY_SCALES", kq_scale_mask);
-
-    int vs_scale_mask = (static_cast<int>(pd()->with_value_scales()) << 1)
-            | static_cast<int>(with_quantize_common(d->vs_scales));
-    kernel_ctx.define_int("VAL_SCALES", vs_scale_mask);
-
-    int kq_zp_mask = (static_cast<int>(pd()->with_key_zp()) << 1)
-            | static_cast<int>(with_quantize_common(d->kq_zero_points));
-    kernel_ctx.define_int("KEY_ZERO_POINTS", kq_zp_mask);
+    kernel_ctx.define_int(
+            "KEY_SCALES", static_cast<int>(pd()->key_scales_quant()));
+    kernel_ctx.define_int(
+            "VAL_SCALES", static_cast<int>(pd()->value_scales_quant()));
 
-    int vs_zp_mask = (static_cast<int>(pd()->with_value_zp()) << 1)
-            | static_cast<int>(with_quantize_common(d->vs_zero_points));
-    kernel_ctx.define_int("VAL_ZERO_POINTS", vs_zp_mask);
+    kernel_ctx.define_int(
+            "KEY_ZERO_POINTS", static_cast<int>(pd()->key_zp_quant()));
+    kernel_ctx.define_int(
+            "VAL_ZERO_POINTS", static_cast<int>(pd()->value_zp_quant()));
 
     using namespace data_type;
     auto elems_per_byte = [](data_type_t dt) {
@@ -551,10 +524,11 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {
 
     kernel_ctx.define_int("REMAINDER_K", !k_full);
 
+    if (ldmsk % 4 == 0) kernel_ctx.define_int("BLOCK_MSK", 1);
+
     if (d_full) {
         if (ldq % 4 == 0) kernel_ctx.define_int("BLOCK_Q", 1);
         if (lda % 4 == 0 && v_full) kernel_ctx.define_int("BLOCK_A", 1);
-        if (ldmsk % 4 == 0) kernel_ctx.define_int("BLOCK_MSK", 1);
         kernel_ctx.define_int("REMAINDER_Q", (d->queries() % tile_q) != 0);
     } else if (pd()->arch() >= compute::gpu_arch_t::xe_hpc) {
         auto vbytes = d->values() * val_mdw.data_type_size();
diff --git a/src/gpu/intel/ocl/micro_sdpa.hpp b/src/gpu/intel/ocl/micro_sdpa.hpp
index 72a9caa5d5a..f0c3d0435ad 100644
--- a/src/gpu/intel/ocl/micro_sdpa.hpp
+++ b/src/gpu/intel/ocl/micro_sdpa.hpp
@@ -47,6 +47,14 @@ struct micro_sdpa_t : public gpu_primitive_t {
         static constexpr int mask_q_index = 2;
         static constexpr int mask_k_index = 3;
 
+        enum class quantization_t {
+            /* Quantization types. Keep in sync with #defines in .cl file. */
+            none = 0,
+            common = 1,
+            per_token = 2,
+            grouped = 3
+        };
+
         DECLARE_COMMON_PD_T("ocl:micro:any", micro_sdpa_t);
 
         status_t init(impl::engine_t *engine) {
@@ -100,8 +108,7 @@ struct micro_sdpa_t : public gpu_primitive_t {
             int kq_scales_mask = desc()->kq_scales.mask_;
             int kq_zp_mask = desc()->kq_zero_points.get(DNNL_ARG_WEIGHTS);
-            if (!desc()->kq_scales.has_default_values()
-                    && !desc()->kq_zero_points.has_default_values())
+            if (with_key_scales() && with_key_zp())
                 VDISPATCH_SDPA(kq_scales_mask == kq_zp_mask,
                         "kq scales mask(%d) must equal kq zero point(%d) "
                         "mask",
@@ -127,8 +134,7 @@ struct micro_sdpa_t : public gpu_primitive_t {
             int vs_scales_mask = desc()->vs_scales.mask_;
             int vs_zp_mask = desc()->vs_zero_points.get(DNNL_ARG_WEIGHTS);
-            if (!desc()->vs_scales.has_default_values()
-                    && !desc()->vs_zero_points.has_default_values())
+            if (with_value_scales() && with_value_zp())
                 VDISPATCH_SDPA(vs_scales_mask == vs_zp_mask,
                         "vs scales mask(%d) must equal vs zero point(%d) "
                         "mask",
@@ -152,8 +158,7 @@ struct micro_sdpa_t : public gpu_primitive_t {
                         value_group_size());
             }
 
-            if (!desc()->vs_scales.has_default_values()
-                    || !desc()->vs_zero_points.has_default_values()) {
+            if (with_value_scales() || with_value_zp()) {
                 int vgs = value_group_size();
                 VDISPATCH_SDPA(
                         math::is_pow2(vgs) || vgs == val_md()->dims[3],
@@ -200,6 +205,54 @@ struct micro_sdpa_t : public gpu_primitive_t {
 
         compute::gpu_arch_t arch() const { return arch_; }
 
+        quantization_t key_scales_quant() const {
+            if (!with_key_scales()) return quantization_t::none;
+            if (is_common_mask(desc()->kq_scales.mask_))
+                return quantization_t::common;
+            if (!with_key_zp() && key_group_size() >= desc()->head_size())
+                return quantization_t::per_token;
+            return quantization_t::grouped;
+        }
+
+        quantization_t key_zp_quant() const {
+            if (!with_key_zp()) return quantization_t::none;
+            if (is_common_mask(desc()->kq_zero_points.get(DNNL_ARG_WEIGHTS)))
+                return quantization_t::common;
+            return quantization_t::grouped;
+        }
+
+        quantization_t value_scales_quant() const {
+            if (!with_value_scales()) return quantization_t::none;
+            if (is_common_mask(desc()->vs_scales.mask_))
+                return quantization_t::common;
+            // Functional, but slower than grouped.
+            // if (!with_value_zp() && value_group_size() >= desc()->head_size())
+            //     return quantization_t::per_token;
+            return quantization_t::grouped;
+        }
+
+        quantization_t value_zp_quant() const {
+            if (!with_value_zp()) return quantization_t::none;
+            if (is_common_mask(desc()->vs_zero_points.get(DNNL_ARG_WEIGHTS)))
+                return quantization_t::common;
+            return quantization_t::grouped;
+        }
+
+        /// Returns true if a common scale value is used for each slice of the tensor
+        /// operation. For 4D case it's when the mask's two first bits are on and two
+        /// last bits are off.
+        /// Examples:
+        ///   | mask      | result |
+        ///   |-----------+--------|
+        ///   | 0  (0000) | true   |
+        ///   | 12 (0011) | false  |
+        ///   | 3  (1100) | true   |
+        ///   | 1  (1000) | true   |
+        ///   | 8  (0001) | false  |
+        static bool is_common_mask(unsigned mask) {
+            return ((mask & 3) != 0 && (mask & 12) == 0) || mask == 0;
+        }
+
     private:
         micro::Package gemm_kq_, gemm_vs_;
         int sg_size_ = 0;
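Note for reviewers: below is a minimal, standalone C++ sketch (not part of the patch) of the mask classification that the new pd_t::quantization_t helpers perform. It reuses the documented is_common_mask() predicate and key_scales_quant()'s fallback to per-token scales when no zero points are present and the quantization group covers the whole head; the with_scales/with_zp/group_size/head_size parameters are simplified stand-ins for the primitive-descriptor queries. The enum values deliberately match the QUANTIZE_COMMON/QUANTIZE_1D/QUANTIZE_2D defines so they can be passed straight to the kernel via kernel_ctx.define_int().

#include <cstdio>

// Mirrors micro_sdpa_t::pd_t::quantization_t; values match the QUANTIZE_*
// defines in micro_sdpa.cl (none = 0 means the attribute is not set).
enum class quantization_t { none = 0, common = 1, per_token = 2, grouped = 3 };

// Same predicate as pd_t::is_common_mask(): at least one of the two low bits
// set and both high bits clear, or a zero mask.
static bool is_common_mask(unsigned mask) {
    return ((mask & 3) != 0 && (mask & 12) == 0) || mask == 0;
}

// Simplified key_scales_quant(): with_zp stands in for with_key_zp(), and
// group_size/head_size for key_group_size()/desc()->head_size().
static quantization_t classify_key_scales(bool with_scales, bool with_zp,
        unsigned mask, int group_size, int head_size) {
    if (!with_scales) return quantization_t::none;
    if (is_common_mask(mask)) return quantization_t::common;
    if (!with_zp && group_size >= head_size) return quantization_t::per_token;
    return quantization_t::grouped;
}

int main() {
    // Cases from the is_common_mask() doc table, plus the new per-token path.
    printf("no scales          -> %d\n",
            (int)classify_key_scales(false, false, 0, 32, 128)); // 0 (none)
    printf("mask 0             -> %d\n",
            (int)classify_key_scales(true, false, 0, 32, 128)); // 1 (common)
    printf("mask 12, group 32  -> %d\n",
            (int)classify_key_scales(true, false, 12, 32, 128)); // 3 (grouped)
    printf("mask 12, group 128 -> %d\n",
            (int)classify_key_scales(true, false, 12, 128, 128)); // 2 (per_token)
    return 0;
}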