Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ggml/src/ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const

const int cc = ggml_cuda_info().devices[device].cc;

// TODO: temporary until support is extended
// https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206
if (K->ne[1] % FATTN_KQ_STRIDE != 0) {
return BEST_FATTN_KERNEL_NONE;
}

switch (K->ne[0]) {
case 64:
case 128:
Expand Down
60 changes: 58 additions & 2 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,13 +918,58 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
bool has_mask,
int32_t ncpsg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_UNUSED(op);

char base[256];
char name[256];

snprintf(base, 256, "kernel_%s",
"flash_attn_ext_pad");

snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
base,
has_mask,
ncpsg);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}

ggml_metal_cv_t cv = ggml_metal_cv_init();

ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);

//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

ggml_metal_cv_free(cv);

return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_library_t lib,
const ggml_tensor * op,
bool has_mask,
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);

Expand All @@ -937,18 +982,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];

// do bounds checks for the mask?
const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);

snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
"flash_attn_ext",
ggml_type_name(op->src[1]->type),
dk,
dv);

snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
base,
has_mask,
has_sinks,
has_bias,
has_scap,
has_kvpad,
bc_mask,
ns10,
ns20,
nsg);
Expand All @@ -964,6 +1014,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);

ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);

ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
Expand All @@ -983,6 +1036,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg,
int32_t nwg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
Expand All @@ -1002,12 +1056,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
dk,
dv);

snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
base,
has_mask,
has_sinks,
has_bias,
has_scap,
has_kvpad,
ns10,
ns20,
nsg, nwg);
Expand All @@ -1023,6 +1078,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);

ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
Expand Down
8 changes: 8 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
bool has_mask,
int32_t ncpsg);

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
bool has_mask,
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg);

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
Expand All @@ -151,6 +158,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg,
int32_t nwg);

Expand Down
31 changes: 26 additions & 5 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,12 @@
#define N_SG_IQ4_XS 2

// function constants offsets
#define FC_FLASH_ATTN_EXT 100
#define FC_FLASH_ATTN_EXT_VEC 200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
#define FC_MUL_MV 400
#define FC_MUL_MM 500
#define FC_FLASH_ATTN_EXT_PAD 100
#define FC_FLASH_ATTN_EXT 200
#define FC_FLASH_ATTN_EXT_VEC 300
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400
#define FC_MUL_MV 500
#define FC_MUL_MM 600

// kernel argument structs
//
Expand Down Expand Up @@ -243,6 +244,24 @@ typedef struct {
int32_t sect_3;
} ggml_metal_kargs_rope;

typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
} ggml_metal_kargs_flash_attn_ext_pad;

typedef struct {
int32_t ne01;
int32_t ne02;
Expand All @@ -261,6 +280,7 @@ typedef struct {
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
Expand Down Expand Up @@ -295,6 +315,7 @@ typedef struct {
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
Expand Down
Loading
Loading