xe: ocl: sdpa: pull grouped scales out of gemm where possible #2409

Open · wants to merge 1 commit into base: main
src/gpu/intel/ocl/micro_sdpa.cl (106 changes: 99 additions & 7 deletions)

@@ -21,11 +21,14 @@
#include "gemm_kq.h"
#include "gemm_vs.h"

/* The quantization parameter may be unique for each token/element */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Random spot: do we have test cases covering new per-token scaling optimized path and the old path?

#define QUANTIZE_2D 2

/* The quantization parameter shares the same value across the work-group */
#define QUANTIZE_COMMON 3
#define QUANTIZE_COMMON 1

/* One quantization parameter for each token */
#define QUANTIZE_1D 2

/* One quantization parameter for each group */
#define QUANTIZE_2D 3

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define DIV_UP(x, y) (((x) + (y)-1) / (y))
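For orientation, a minimal host-side sketch of what the three granularities mean for one tensor slice; the function and parameter names here are hypothetical, not part of this PR:

enum quant_granularity { QUANT_COMMON = 1, QUANT_1D = 2, QUANT_2D = 3 };

// Illustrative classification: how many scale values cover a slice of
// `num_tokens` tokens decides which kernel path applies.
static quant_granularity classify_scales(int num_scales, int num_tokens) {
    if (num_scales == 1) return QUANT_COMMON; // one value shared by the slice
    if (num_scales == num_tokens) return QUANT_1D; // one value per token
    return QUANT_2D; // finer: one value per group of elements within a token
}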
@@ -93,6 +96,34 @@ DECLARE_2D_TILE_BLOCK_OPS(mask_tile_type, MSK_DATA_T, SUBGROUP_SIZE, mask_br,
mask_bc, mask_nbr, mask_nbc)
#endif

#if KEY_SCALES == QUANTIZE_1D
DECLARE_2D_TILE(k_scales_tile_type, KEY_ATTR_SCALES_DATA_T, SUBGROUP_SIZE,
ugemm_kq_sg_tile_m, 1, 1, 1)
DECLARE_2D_TILE(k_scales_tile_type_float, float, SUBGROUP_SIZE,
ugemm_kq_sg_tile_m, 1, 1, 1)
DECLARE_2D_TILE_BLOCK_OPS(k_scales_tile_type, KEY_ATTR_SCALES_DATA_T,
SUBGROUP_SIZE, ugemm_kq_sg_tile_m, 1, 1, 1)

DECLARE_2D_TILE_HREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0,
ugemm_kq_c_type_nblock1, k_scales_tile_type_float, SUBGROUP_SIZE,
ugemm_kq_sg_tile_m, 1, 1, 1)
#endif

#if VAL_SCALES == QUANTIZE_1D
DECLARE_2D_TILE(v_scales_tile_type, KEY_ATTR_SCALES_DATA_T, SUBGROUP_SIZE,
ugemm_kq_sg_tile_m, 1, 1, 1)
DECLARE_2D_TILE(v_scales_tile_type_float, float, SUBGROUP_SIZE,
ugemm_kq_sg_tile_m, 1, 1, 1)
DECLARE_2D_TILE_BLOCK_OPS(v_scales_tile_type, KEY_ATTR_SCALES_DATA_T,
SUBGROUP_SIZE, ugemm_kq_sg_tile_m, 1, 1, 1)

DECLARE_2D_TILE_HREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0,
ugemm_kq_c_type_nblock1, v_scales_tile_type_float, SUBGROUP_SIZE,
ugemm_kq_sg_tile_m, 1, 1, 1)
Contributor, commenting on lines +113 to +124, suggested a change (the value-scale tiles should use VAL_ATTR_SCALES_DATA_T, not KEY_ATTR_SCALES_DATA_T):

-#if VAL_SCALES == QUANTIZE_1D
-DECLARE_2D_TILE(v_scales_tile_type, KEY_ATTR_SCALES_DATA_T, SUBGROUP_SIZE,
-ugemm_kq_sg_tile_m, 1, 1, 1)
-DECLARE_2D_TILE(v_scales_tile_type_float, float, SUBGROUP_SIZE,
-ugemm_kq_sg_tile_m, 1, 1, 1)
-DECLARE_2D_TILE_BLOCK_OPS(v_scales_tile_type, KEY_ATTR_SCALES_DATA_T,
-SUBGROUP_SIZE, ugemm_kq_sg_tile_m, 1, 1, 1)
-DECLARE_2D_TILE_HREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
-ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0,
-ugemm_kq_c_type_nblock1, v_scales_tile_type_float, SUBGROUP_SIZE,
-ugemm_kq_sg_tile_m, 1, 1, 1)
+#if VAL_SCALES == QUANTIZE_1D
+DECLARE_2D_TILE(v_scales_tile_type, VAL_ATTR_SCALES_DATA_T, SUBGROUP_SIZE,
+ugemm_kq_sg_tile_m, 1, 1, 1)
+DECLARE_2D_TILE(v_scales_tile_type_float, float, SUBGROUP_SIZE,
+ugemm_kq_sg_tile_m, 1, 1, 1)
+DECLARE_2D_TILE_BLOCK_OPS(v_scales_tile_type, VAL_ATTR_SCALES_DATA_T,
+SUBGROUP_SIZE, ugemm_kq_sg_tile_m, 1, 1, 1)
+DECLARE_2D_TILE_HREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
+ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0,
+ugemm_kq_c_type_nblock1, v_scales_tile_type_float, SUBGROUP_SIZE,
+ugemm_kq_sg_tile_m, 1, 1, 1)

Author: Good catch, thanks.

#endif

#ifdef BLOCK_A
DECLARE_2D_TILE_BLOCK_OPS(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE,
ugemm_vs_sg_tile_m, 1, 1, ugemm_vs_sg_tile_n)
@@ -299,7 +330,19 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
/* sg_size */ SUBGROUP_SIZE,
/* cache */ LSC_LDCC_L1C_L3C);

-#if KEY_SCALES == QUANTIZE_2D
+#if KEY_SCALES == QUANTIZE_1D
+cooperative_prefetch_2d_maybe_rem(
+/* ptr */ K_scales,
+/* r */ k,
+/* c */ 1,
+/* rmax */ ugemm_kq_wg_tile_m,
+/* cmax */ 1,
+/* ld */ ldkq,
+/* sg_id */ sg_ij,
+/* n_sg */ sg_per_wg,
+/* sg_size */ SUBGROUP_SIZE,
+/* cache */ LSC_LDCC_L1C_L3C);
+#elif KEY_SCALES == QUANTIZE_2D
cooperative_prefetch_2d_maybe_rem(
/* ptr */ K_scales,
/* r */ k,
@@ -414,6 +457,10 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
#if KEY_SCALES == QUANTIZE_COMMON
#define k_scale_op(x) ((x)*k_scale)
tile_elementwise(S_tile, k_scale_op);
+#elif KEY_SCALES == QUANTIZE_1D
+/* Load 1D K scales */
+k_scales_tile_type k_scales_tile;
+tile_load_block(&k_scales_tile, K_scales, 0, k0 + sg_i0_kq, 0);
#endif

/* Apply attention mask */
@@ -436,6 +483,13 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
tile_hbroadcast_min(&S_tile, k_mask);
#endif

+/* Apply 1D K scales */
+#if KEY_SCALES == QUANTIZE_1D
+k_scales_tile_type_float k_scales_tile_float;
+tile_copy(k_scales_tile, k_scales_tile_float);
+tile_hbroadcast_mul(&S_tile, k_scales_tile_float);
+#endif
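This load-then-broadcast is the heart of the optimization: with one scale per key token, dequantization commutes with the K^T * Q product, since dequant(K) = Kq * diag(k_scale) implies dequant(K)^T * Q = diag(k_scale) * (Kq^T * Q). Row i of S therefore just gets multiplied by k_scale[i] after the integer GEMM. A scalar C++ sketch with illustrative layouts:

#include <cstdint>

// Run the quantized K^T*Q product first, then scale each S row by its
// token's scale: the scalar analogue of tile_hbroadcast_mul above.
// K is stored d x keys, Q is d x queries; all names are illustrative.
void kq_then_scale(const int8_t *Kq, const float *Q, const float *k_scale,
        float *S, int d, int keys, int queries) {
    for (int i = 0; i < keys; ++i)
        for (int j = 0; j < queries; ++j) {
            float acc = 0.0f;
            for (int t = 0; t < d; ++t)
                acc += float(Kq[t * keys + i]) * Q[t * queries + j];
            S[i * queries + j] = k_scale[i] * acc; // row-broadcast multiply
        }
}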

#if WITH_CAUSAL_MASK
#define greater_than(offset_k, offset_q) (offset_k > offset_q)
/* Apply causal mask */
@@ -468,7 +522,20 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
/* sg_size */ SUBGROUP_SIZE,
/* cache */ LSC_LDCC_L1C_L3C);

-#if VAL_SCALES == QUANTIZE_2D
+#if VAL_SCALES == QUANTIZE_1D
+/* Prefetch 1D V scales. */
+cooperative_prefetch_2d_maybe_rem(
+/* ptr */ V_scales,
+/* r */ 1,
+/* c */ k - k0,
+/* rmax */ 1,
+/* cmax */ k_chunk,
+/* ld */ ldvq,
+/* sg_id */ sg_ij,
+/* n_sg */ sg_per_wg,
+/* sg_size */ SUBGROUP_SIZE,
+/* cache */ LSC_LDCC_L1C_L3C);
+#elif VAL_SCALES == QUANTIZE_2D
/* Prefetch V scales. */
cooperative_prefetch_2d_maybe_rem(
/* ptr */ V_scales,
@@ -509,6 +576,12 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
#define scaled_exp(x) native_vexp2(x *scale)
tile_elementwise(S_tile, scaled_exp);
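Note that native_vexp2 is a base-2 exponential: exp(x * s) equals 2^(x * s * log2(e)), so this only matches a softmax if the host folds log2(e) into `scale` ahead of time; that host-side preparation is not shown in this diff. A reference in C++ under that assumption:

#include <cmath>

// Assumes scale_log2e = scale * log2(e), precomputed once on the host.
float scaled_exp_ref(float x, float scale_log2e) {
    return std::exp2(x * scale_log2e); // == std::exp(x * scale)
}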

+#if VAL_SCALES == QUANTIZE_1D
+/* Load 1D V scales */
+v_scales_tile_type v_scales_tile;
+tile_load_block(&v_scales_tile, V_scales, 0, k0 + sg_i0_kq, 0);
+#endif

#ifdef ALT_MAX
/* Read back WG-wide maxima and adjust S to match */
intel_work_group_barrier_wait(CLK_LOCAL_MEM_FENCE);
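For context, "adjust S to match" refers to the standard online-softmax correction, sketched generically below (the textbook technique, not this kernel's exact code): once the work-group-wide maximum is known, values that were exponentiated against a smaller local maximum are rescaled so all terms share one reference point.

#include <cmath>

// Generic rescale: s_row was exponentiated against old_max; after the
// WG-wide reduction yields new_max >= old_max, multiply by
// 2^(old_max - new_max) (base 2 to match the exp2-based softmax above).
void rescale_partials(float *s_row, int n, float old_max, float new_max) {
    const float corr = std::exp2(old_max - new_max);
    for (int j = 0; j < n; ++j)
        s_row[j] *= corr;
}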
@@ -526,6 +599,13 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
tile_fill(S_sum_tile1, 0.0f);
tile_vreduce_add(S_tile, &S_sum_tile1);
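Modeling tile_vreduce_add with plain loops (shapes illustrative): a vertical, column-wise reduction that produces one partial softmax denominator per query column of the S tile.

// Accumulate column sums of a rows x cols tile into col_sum.
void vreduce_add_ref(const float *S, float *col_sum, int rows, int cols) {
    for (int j = 0; j < cols; ++j) col_sum[j] = 0.0f;
    for (int i = 0; i < rows; ++i)
        for (int j = 0; j < cols; ++j)
            col_sum[j] += S[i * cols + j];
}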

+#if VAL_SCALES == QUANTIZE_1D
+/* Apply 1D V scales to S tile */
+v_scales_tile_type_float v_scales_tile_float;
+tile_copy(v_scales_tile, v_scales_tile_float);
+tile_hbroadcast_mul(&S_tile, v_scales_tile_float);
+#endif
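This is the value-side counterpart of the K-scale trick, and the ordering above matters: the softmax row sums are accumulated before these scales are folded in, because the scales belong to the V*S product, not to the normalization. With one scale per value token t, dequant(V) = Vq * diag(v_scale), so O = dequant(V) * S = Vq * S', where row t of S' is v_scale[t] times row t of S. A scalar sketch, layouts illustrative:

// Fold per-token V scales into the softmax tile so the V*S GEMM can run
// directly on the quantized V data.
void fold_v_scales(float *S, const float *v_scale, int keys, int queries) {
    for (int t = 0; t < keys; ++t)
        for (int j = 0; j < queries; ++j)
            S[t * queries + j] *= v_scale[t];
}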

/* Convert to half or bf16, VNNI format */
s_tile_type_packed S_tile_packed;
tile_copy_to_vec2(S_tile, S_tile_packed, VEC_TYPE2);
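A rough model of the VNNI-style pairing this conversion feeds; the kernel's exact tile layout is defined elsewhere, so the sketch below only shows the common idea of interleaving adjacent reduction-dimension elements so the systolic hardware can fetch each 16-bit pair as one dword:

#include <cstdint>

// Pack pairs along the reduction dim: (k, j) and (k+1, j) become adjacent.
// Assumes even k; real code also handles the tail.
void vnni_pack16(const uint16_t *src, uint16_t *dst, int k, int n) {
    for (int kk = 0; kk < k; kk += 2)
        for (int j = 0; j < n; ++j) {
            dst[(kk / 2) * (2 * n) + 2 * j + 0] = src[(kk + 0) * n + j];
            dst[(kk / 2) * (2 * n) + 2 * j + 1] = src[(kk + 1) * n + j];
        }
}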
@@ -590,7 +670,19 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
/* n_sg */ sg_per_wg,
/* sg_size */ SUBGROUP_SIZE,
/* cache*/ LSC_LDCC_L1C_L3C);
-#if KEY_SCALES == QUANTIZE_2D
+#if KEY_SCALES == QUANTIZE_1D
+cooperative_prefetch_2d_maybe_rem(
+/* ptr */ K_scales + (k0 + ugemm_kq_wg_tile_m),
+/* r */ k - k0 - ugemm_kq_wg_tile_m,
+/* c */ 1,
+/* rmax */ ugemm_kq_wg_tile_m,
+/* cmax */ 1,
+/* ld */ ldkq,
+/* sg_id */ sg_ij,
+/* n_sg */ sg_per_wg,
+/* sg_size */ SUBGROUP_SIZE,
+/* cache */ LSC_LDCC_L1C_L3C);
+#elif KEY_SCALES == QUANTIZE_2D
cooperative_prefetch_2d_maybe_rem(
/* ptr */ K_scales + (k0 + ugemm_kq_wg_tile_m),
/* r */ k - k0 - ugemm_kq_wg_tile_m,
src/gpu/intel/ocl/micro_sdpa.cpp (90 changes: 32 additions & 58 deletions)

@@ -192,31 +192,6 @@ sdpa_config_t *choose_config_xehpc(
return nullptr;
}

-/// Returns true if a common scale value is used for each slice of the tensor
-/// operation. For 4D case it's when the mask's two first bits are on and two
-/// last bits are off.
-/// Examples:
-/// | mask      | result |
-/// |-----------+--------|
-/// | 0 (0000)  | true   |
-/// | 12 (0011) | false  |
-/// | 3 (1100)  | true   |
-/// | 1 (1000)  | true   |
-/// | 8 (0001)  | false  |
-bool with_quantize_common(const runtime_scales_t &scales) {
-    return !scales.has_default_values()
-            && (((scales.mask_ & 3) != 0 && (scales.mask_ & 12) == 0)
-                    || scales.mask_ == 0);
-}
-
-/// Returns true if a common zero points value is used for each slice of the
-/// tensor operation
-bool with_quantize_common(const zero_points_t &zp) {
-    int mask = zp.get(DNNL_ARG_WEIGHTS);
-    return !zp.has_default_values()
-            && (((mask & 3) != 0 && (mask & 12) == 0) || mask == 0);
-}
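A worked check of the removed predicate against the doc table above (the bit strings in the table are written low-bit-first):

#include <cassert>

// Reproduction of the removed common-quantization test.
static bool is_common_mask(int mask) {
    return ((mask & 3) != 0 && (mask & 12) == 0) || mask == 0;
}

int main() {
    assert(is_common_mask(0));   // empty mask: one common value
    assert(is_common_mask(1));   // bit 0 set: common per slice
    assert(is_common_mask(3));   // bits 0 and 1 set: common per slice
    assert(!is_common_mask(8));  // bit 3 set: varies over a trailing dim
    assert(!is_common_mask(12)); // bits 2 and 3 set: not common
    return 0;
}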

} /* anonymous namespace */

status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
@@ -270,8 +245,11 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
: MatrixLayout::N;
};

-bool kq_common_scales = with_quantize_common(d->kq_scales);
-bool kq_common_zp = with_quantize_common(d->kq_zero_points);
+auto k_scales_q = key_scales_quant();
+auto k_zp_q = key_zp_quant();
+
+bool k_scales_2d = (k_scales_q == pd_t::quantization_t::grouped);
+bool k_zp_2d = (k_zp_q == pd_t::quantization_t::grouped);
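Only the `grouped` enumerator of pd_t::quantization_t is visible in this diff. Judging from the kernel macros it is static_cast into further down, the enum plausibly has the shape below; this is an inference, not the actual oneDNN definition:

// Hypothetical reconstruction, inferred from the QUANTIZE_* macros.
enum class quantization_t : int {
    none = 0, // no scales or zero-points
    common = 1, // QUANTIZE_COMMON: one value per work-group
    per_token = 2, // QUANTIZE_1D: one value per token
    grouped = 3, // QUANTIZE_2D: one value per group
};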

/* Set up GEMMProblem structure for first GEMM: K^T * Q */
GEMMProblem problem;
@@ -293,7 +271,7 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
opts_kq.localB = true;
opts_kq.slmPtr = true;

-if (with_key_scales() && !kq_common_scales) {
+if (k_scales_2d) {
auto scale_dt = key_scales_dt();
problem_kq.Ta_scale = jit::convert_dnnl_to_kernel_type(scale_dt);
problem_kq.A_scale.setAlignment(
@@ -307,16 +285,14 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
problem_kq.AO.setAlignment(
int8_t(d->keys() * types::data_type_size(zp_dt)));
problem_kq.AO.layout = MatrixLayout::N;
-problem_kq.aoPtrDims = kq_common_zp ? 0 : 2;
+problem_kq.aoPtrDims = k_zp_2d ? 2 : 0;
problem_kq.aOffset = ABOffset::Calc;
}

-if (with_key_scales() || with_key_zp()) {
+if (k_scales_2d || with_key_zp()) {
problem_kq.aqGroupM = 1;
-problem_kq.aqGroupK
-= (kq_common_scales || kq_common_zp) ? 1 : key_group_size();
+problem_kq.aqGroupK = (k_scales_2d || k_zp_2d) ? key_group_size() : 1;
}
-opts_kq.scaleA = with_key_scales() && !kq_common_scales;
+opts_kq.scaleA = k_scales_2d;
opts_kq.offsetA = with_key_zp();

problem_kq.B.layout = MatrixLayout::Pr;
@@ -354,8 +330,11 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
ex.what());
}

-bool vs_common_scales = with_quantize_common(d->vs_scales);
-bool vs_common_zp = with_quantize_common(d->vs_zero_points);
+auto v_scales_q = value_scales_quant();
+auto v_zp_q = value_zp_quant();
+
+bool v_scales_2d = (v_scales_q == pd_t::quantization_t::grouped);
+bool v_zp_2d = (v_zp_q == pd_t::quantization_t::grouped);

/* Set up microkernel options */
micro::GEMMProtocol::Options opts_vs;
@@ -366,7 +345,7 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
auto problem_vs = problem;
problem_vs.Ta_ext = jit::convert_dnnl_to_kernel_type(val_md()->data_type);
problem_vs.A.layout = convert_dnnl_to_kernel_layout(val_md());
-if (with_value_scales() && !vs_common_scales) {
+if (v_scales_2d) {
auto scale_dt = value_scales_dt();
problem_vs.Ta_scale = jit::convert_dnnl_to_kernel_type(scale_dt);
problem_vs.A_scale.setAlignment(uint8_t(d->head_size()
@@ -381,16 +360,16 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
problem_vs.AO.setAlignment(uint8_t(d->head_size() / value_group_size()
* types::data_type_size(zp_dt)));
problem_vs.AO.layout = MatrixLayout::N;
-problem_vs.aoPtrDims = vs_common_zp ? 0 : 2;
+problem_vs.aoPtrDims = v_zp_2d ? 2 : 0;
problem_vs.aOffset = ABOffset::Calc;
}
-if (with_value_scales() || with_value_zp()) {
-problem_vs.aqGroupM = (vs_common_scales || vs_common_zp)
-? 1
-: utils::rnd_up_pow2(value_group_size());
+if (v_scales_2d || with_value_zp()) {
+problem_vs.aqGroupM = (v_scales_2d || v_zp_2d)
+? utils::rnd_up_pow2(value_group_size())
+: 1;
problem_vs.aqGroupK = 1;
}
-opts_vs.scaleA = with_value_scales() && !vs_common_scales;
+opts_vs.scaleA = v_scales_2d;
opts_vs.offsetA = with_value_zp();

problem_vs.B.layout = MatrixLayout::Pr;
@@ -489,21 +468,15 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {
kernel_ctx.define_int("TRANSPOSE_K",
gemm_desc_t::get_trans(*pd()->key_md()) == dnnl_trans);

-int kq_scale_mask = (static_cast<int>(pd()->with_key_scales()) << 1)
-| static_cast<int>(with_quantize_common(d->kq_scales));
-kernel_ctx.define_int("KEY_SCALES", kq_scale_mask);
-
-int vs_scale_mask = (static_cast<int>(pd()->with_value_scales()) << 1)
-| static_cast<int>(with_quantize_common(d->vs_scales));
-kernel_ctx.define_int("VAL_SCALES", vs_scale_mask);
-
-int kq_zp_mask = (static_cast<int>(pd()->with_key_zp()) << 1)
-| static_cast<int>(with_quantize_common(d->kq_zero_points));
-kernel_ctx.define_int("KEY_ZERO_POINTS", kq_zp_mask);
-
-int vs_zp_mask = (static_cast<int>(pd()->with_value_zp()) << 1)
-| static_cast<int>(with_quantize_common(d->vs_zero_points));
-kernel_ctx.define_int("VAL_ZERO_POINTS", vs_zp_mask);
+kernel_ctx.define_int(
+"KEY_SCALES", static_cast<int>(pd()->key_scales_quant()));
+kernel_ctx.define_int(
+"VAL_SCALES", static_cast<int>(pd()->value_scales_quant()));
+
+kernel_ctx.define_int(
+"KEY_ZERO_POINTS", static_cast<int>(pd()->key_zp_quant()));
+kernel_ctx.define_int(
+"VAL_ZERO_POINTS", static_cast<int>(pd()->value_zp_quant()));

using namespace data_type;
auto elems_per_byte = [](data_type_t dt) {
@@ -551,10 +524,11 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {

kernel_ctx.define_int("REMAINDER_K", !k_full);

+if (ldmsk % 4 == 0) kernel_ctx.define_int("BLOCK_MSK", 1);

if (d_full) {
if (ldq % 4 == 0) kernel_ctx.define_int("BLOCK_Q", 1);
if (lda % 4 == 0 && v_full) kernel_ctx.define_int("BLOCK_A", 1);
-if (ldmsk % 4 == 0) kernel_ctx.define_int("BLOCK_MSK", 1);
kernel_ctx.define_int("REMAINDER_Q", (d->queries() % tile_q) != 0);
} else if (pd()->arch() >= compute::gpu_arch_t::xe_hpc) {
auto vbytes = d->values() * val_mdw.data_type_size();