diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt index 13ecced8285..182114c2697 100644 --- a/examples/talk-llama/CMakeLists.txt +++ b/examples/talk-llama/CMakeLists.txt @@ -16,8 +16,8 @@ if (WHISPER_SDL2) llama-hparams.cpp llama-impl.cpp llama-io.cpp - llama-kv-cache-unified.cpp - llama-kv-cache-unified-iswa.cpp + llama-kv-cache.cpp + llama-kv-cache-iswa.cpp llama-memory-recurrent.cpp llama-memory-hybrid.cpp llama-memory.cpp diff --git a/examples/talk-llama/llama-adapter.cpp b/examples/talk-llama/llama-adapter.cpp index 8d94034aed9..d8eef75a7ad 100644 --- a/examples/talk-llama/llama-adapter.cpp +++ b/examples/talk-llama/llama-adapter.cpp @@ -6,6 +6,7 @@ #include #include +#include #include // vec @@ -163,13 +164,38 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ // check metadata { + const gguf_context * gguf_ctx = ctx_gguf.get(); + + LLAMA_LOG_INFO("%s: Dumping metadata keys/values.\n", __func__); + + // get metadata as string + for (int i = 0; i < gguf_get_n_kv(gguf_ctx); i++) { + gguf_type type = gguf_get_kv_type(gguf_ctx, i); + const std::string type_name = + type == GGUF_TYPE_ARRAY + ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(gguf_ctx, i)), gguf_get_arr_n(gguf_ctx, i)) + : gguf_type_name(type); + const char * name = gguf_get_key(gguf_ctx, i); + const std::string value = gguf_kv_to_str(gguf_ctx, i); + + if (type != GGUF_TYPE_ARRAY) { + adapter.gguf_kv.emplace(name, value); + } + + const size_t MAX_VALUE_LEN = 40; + std::string print_value = value.size() > MAX_VALUE_LEN ? format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()) : value; + replace_all(print_value, "\n", "\\n"); + + LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), print_value.c_str()); + } + auto get_kv_str = [&](const std::string & key) -> std::string { - int id = gguf_find_key(ctx_gguf.get(), key.c_str()); - return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id)); + int id = gguf_find_key(gguf_ctx, key.c_str()); + return id < 0 ? "" : std::string(gguf_get_val_str(gguf_ctx, id)); }; auto get_kv_f32 = [&](const std::string & key) -> float { - int id = gguf_find_key(ctx_gguf.get(), key.c_str()); - return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id); + int id = gguf_find_key(gguf_ctx, key.c_str()); + return id < 0 ? 0.0f : gguf_get_val_f32(gguf_ctx, id); }; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); @@ -190,6 +216,26 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ } adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA)); + + // parse alora invocation sequence vector + const auto & key = llm_kv(LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS); + const int kid = gguf_find_key(ctx_gguf.get(), key.c_str()); + if (kid >= 0) { + if (gguf_get_kv_type(ctx_gguf.get(), kid) != GGUF_TYPE_ARRAY) { + throw std::runtime_error("invalid gguf type for " + key); + } + const auto arr_type = gguf_get_arr_type(ctx_gguf.get(), kid); + if (arr_type != GGUF_TYPE_UINT32) { + throw std::runtime_error("invalid gguf element type for " + key); + } + const size_t seq_len = gguf_get_arr_n(ctx_gguf.get(), kid); + const void * data = gguf_get_arr_data(ctx_gguf.get(), kid); + adapter.alora_invocation_tokens.resize(seq_len); + std::copy( + (const llama_token *)data, + (const llama_token *)data + seq_len, + adapter.alora_invocation_tokens.begin()); + } } int n_tensors = gguf_get_n_tensors(ctx_gguf.get()); @@ -383,6 +429,57 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p return nullptr; } +int32_t llama_adapter_meta_val_str(const llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size) { + const auto & it = adapter->gguf_kv.find(key); + if (it == adapter->gguf_kv.end()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + return snprintf(buf, buf_size, "%s", it->second.c_str()); +} + +int32_t llama_adapter_meta_count(const llama_adapter_lora * adapter) { + return (int)adapter->gguf_kv.size(); +} + +int32_t llama_adapter_meta_key_by_index(const llama_adapter_lora * adapter, int i, char * buf, size_t buf_size) { + if (i < 0 || i >= (int)adapter->gguf_kv.size()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + auto it = adapter->gguf_kv.begin(); + std::advance(it, i); + return snprintf(buf, buf_size, "%s", it->first.c_str()); +} + +int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size) { + if (i < 0 || i >= (int)adapter->gguf_kv.size()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + auto it = adapter->gguf_kv.begin(); + std::advance(it, i); + return snprintf(buf, buf_size, "%s", it->second.c_str()); +} + void llama_adapter_lora_free(llama_adapter_lora * adapter) { delete adapter; } + +uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) { + if (!adapter) { + return 0; + } + return adapter->alora_invocation_tokens.size(); +} + +const llama_token * llama_adapter_get_alora_invocation_tokens(const llama_adapter_lora * adapter) { + GGML_ASSERT(adapter); + return adapter->alora_invocation_tokens.data(); +} diff --git a/examples/talk-llama/llama-adapter.h b/examples/talk-llama/llama-adapter.h index 65824e97276..4f65247c0fe 100644 --- a/examples/talk-llama/llama-adapter.h +++ b/examples/talk-llama/llama-adapter.h @@ -67,6 +67,12 @@ struct llama_adapter_lora { float alpha; + // gguf metadata + std::unordered_map gguf_kv; + + // activated lora (aLoRA) + std::vector alora_invocation_tokens; + llama_adapter_lora() = default; ~llama_adapter_lora() = default; diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 18dcc6ddfe5..a4d2973ada5 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -22,6 +22,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, { LLM_ARCH_NEO_BERT, "neo-bert" }, { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, + { LLM_ARCH_JINA_BERT_V3, "jina-bert-v3" }, { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, @@ -44,6 +45,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, + { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA2, "mamba2" }, @@ -68,6 +70,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE4, "exaone4" }, { LLM_ARCH_RWKV6, "rwkv6" }, @@ -93,6 +96,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DREAM, "dream" }, { LLM_ARCH_SMALLTHINKER, "smallthinker" }, { LLM_ARCH_LLADA, "llada" }, + { LLM_ARCH_LLADA_MOE, "llada-moe" }, + { LLM_ARCH_SEED_OSS, "seed_oss" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -133,7 +138,9 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, + { LLM_KV_DECODER_BLOCK_COUNT, "%s.decoder_block_count" }, { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, + { LLM_KV_ROUTER_LOGIT_SOFTCAPPING, "%s.router_logit_softcapping" }, { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, { LLM_KV_SWIN_NORM, "%s.swin_norm" }, { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" }, @@ -164,19 +171,25 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, + { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, - { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, - { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, - { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, - { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, - { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, - { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, - { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, - { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, - { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, - { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, + { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, { LLM_KV_SPLIT_NO, "split.no" }, { LLM_KV_SPLIT_COUNT, "split.count" }, @@ -233,8 +246,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, - { LLM_KV_ADAPTER_TYPE, "adapter.type" }, - { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + { LLM_KV_ADAPTER_TYPE, "adapter.type" }, + { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + { LLM_KV_ADAPTER_LORA_TASK_NAME, "adapter.lora.task_name" }, + { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" }, + { LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" }, // deprecated { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, @@ -390,12 +406,16 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, }, @@ -574,6 +594,20 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_CLS, "cls" }, }, }, + { + LLM_ARCH_JINA_BERT_V3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + }, + }, { LLM_ARCH_BLOOM, { @@ -1019,6 +1053,27 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" }, }, }, + { + LLM_ARCH_GEMMA_EMBEDDING, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_STARCODER2, { @@ -1532,6 +1587,31 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_NEMOTRON_H, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + // mamba(2) ssm layers + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + // attention layers + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + // dense FFN + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_EXAONE, { @@ -2010,6 +2090,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, } }, { @@ -2067,6 +2148,43 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_LLADA_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_SEED_OSS, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2319,6 +2437,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_PLAMO2: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_LFM2: + case LLM_ARCH_NEMOTRON_H: return true; default: return false; @@ -2329,6 +2448,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { switch (arch) { case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: + case LLM_ARCH_LLADA_MOE: return true; default: return false; diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 7af587e7951..d181ce6784f 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -26,6 +26,7 @@ enum llm_arch { LLM_ARCH_NOMIC_BERT_MOE, LLM_ARCH_NEO_BERT, LLM_ARCH_JINA_BERT_V2, + LLM_ARCH_JINA_BERT_V3, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, @@ -48,6 +49,7 @@ enum llm_arch { LLM_ARCH_GEMMA2, LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, + LLM_ARCH_GEMMA_EMBEDDING, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_MAMBA2, @@ -72,6 +74,7 @@ enum llm_arch { LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, LLM_ARCH_NEMOTRON, + LLM_ARCH_NEMOTRON_H, LLM_ARCH_EXAONE, LLM_ARCH_EXAONE4, LLM_ARCH_RWKV6, @@ -97,6 +100,8 @@ enum llm_arch { LLM_ARCH_DREAM, LLM_ARCH_SMALLTHINKER, LLM_ARCH_LLADA, + LLM_ARCH_LLADA_MOE, + LLM_ARCH_SEED_OSS, LLM_ARCH_UNKNOWN, }; @@ -137,7 +142,9 @@ enum llm_kv { LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, + LLM_KV_DECODER_BLOCK_COUNT, LLM_KV_ATTN_LOGIT_SOFTCAPPING, + LLM_KV_ROUTER_LOGIT_SOFTCAPPING, LLM_KV_FINAL_LOGIT_SOFTCAPPING, LLM_KV_SWIN_NORM, LLM_KV_RESCALE_EVERY_N_LAYERS, @@ -168,6 +175,8 @@ enum llm_kv { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SCALE, + LLM_KV_ATTENTION_OUTPUT_SCALE, + LLM_KV_ATTENTION_TEMPERATURE_LENGTH, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, @@ -181,6 +190,10 @@ enum llm_kv { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, LLM_KV_ROPE_SCALING_YARN_LOG_MUL, + LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, + LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, + LLM_KV_ROPE_SCALING_YARN_BETA_FAST, + LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, LLM_KV_SPLIT_NO, LLM_KV_SPLIT_COUNT, @@ -229,6 +242,9 @@ enum llm_kv { LLM_KV_ADAPTER_TYPE, LLM_KV_ADAPTER_LORA_ALPHA, + LLM_KV_ADAPTER_LORA_TASK_NAME, + LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, + LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, LLM_KV_POSNET_EMBEDDING_LENGTH, LLM_KV_POSNET_BLOCK_COUNT, diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index 0a96a9a579e..66e6c6a38f1 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -16,10 +16,10 @@ static std::string trim(const std::string & str) { size_t start = 0; size_t end = str.size(); - while (start < end && isspace(str[start])) { + while (start < end && isspace(static_cast(str[start]))) { start += 1; } - while (end > start && isspace(str[end - 1])) { + while (end > start && isspace(static_cast(str[end - 1]))) { end -= 1; } return str.substr(start, end - start); @@ -69,6 +69,8 @@ static const std::map LLM_CHAT_TEMPLATES = { { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE }, { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, + { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, + { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -201,6 +203,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { return LLM_CHAT_TEMPLATE_KIMI_K2; + } else if (tmpl_contains("")) { + return LLM_CHAT_TEMPLATE_SEED_OSS; + } else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) { + return LLM_CHAT_TEMPLATE_GROK_2; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -752,6 +758,28 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|im_assistant|>assistant<|im_middle|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_SEED_OSS) { + for (auto message: chat) { + std::string role(message->role); + ss << "" << role << "\n" << (role == "assistant" ? trim(message->content) : message->content) << ""; + } + if (add_ass) { + ss << "assistant\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GROK_2) { + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "System: " << trim(message->content) << "<|separator|>\n\n"; + } else if (role == "user") { + ss << "Human: " << trim(message->content) << "<|separator|>\n\n"; + } else if (role == "assistant") { + ss << "Assistant: " << message->content << "<|separator|>\n\n"; + } + } + if (add_ass) { + ss << "Assistant:"; + } } else { // template not supported return -1; diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index 35a943856fa..5a87d9ab627 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -49,6 +49,8 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_OPENAI_MOE, LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, LLM_CHAT_TEMPLATE_KIMI_K2, + LLM_CHAT_TEMPLATE_SEED_OSS, + LLM_CHAT_TEMPLATE_GROK_2, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 7d7abad5d4a..e6f76421cf1 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -35,14 +35,12 @@ llama_context::llama_context( cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; - cparams.yarn_ext_factor = params.yarn_ext_factor; - cparams.yarn_attn_factor = params.yarn_attn_factor; - cparams.yarn_beta_fast = params.yarn_beta_fast; - cparams.yarn_beta_slow = params.yarn_beta_slow; - cparams.defrag_thold = params.defrag_thold; + cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; + cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; + cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; + cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; - cparams.flash_attn = params.flash_attn; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; cparams.warmup = false; @@ -87,13 +85,15 @@ llama_context::llama_context( cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; } + cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; + // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) // ref: https://github.com/ggerganov/llama.cpp/pull/5021 - // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self + // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory if (cparams.n_batch < GGML_KQ_MASK_PAD) { LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); cparams.n_batch = GGML_KQ_MASK_PAD; @@ -103,16 +103,6 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; - { - const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); - supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows; - - if (!supports_set_rows && !cparams.kv_unified) { - LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__); - cparams.kv_unified = true; - } - } - { const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; @@ -130,7 +120,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); - LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type)); LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -145,11 +135,6 @@ llama_context::llama_context( __func__, n_ctx_per_seq, hparams.n_ctx_train); } - if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) { - LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n", - __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573"); - } - if (!hparams.vocab_only) { // GPU backends for (auto * dev : model.devices) { @@ -196,7 +181,7 @@ llama_context::llama_context( // graph outputs buffer { // resized during inference when a batch uses more outputs - if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) { + if (output_reserve(params.n_seq_max) < params.n_seq_max) { throw std::runtime_error("failed to reserve initial output buffer"); } @@ -285,28 +270,75 @@ llama_context::llama_context( } } - // reserve worst-case graph - if (!hparams.vocab_only && memory) { + if (!hparams.vocab_only) { + llama_memory_context_ptr mctx; + if (memory) { + LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); + mctx = memory->init_full(); + if (!mctx) { + throw std::runtime_error("failed to initialize memory module"); + } + } + + cross.v_embd.clear(); + const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + // avoid reserving graphs with zero outputs - assume one output per sequence + n_outputs = n_seqs; + LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); + // resolve automatic Flash Attention use + if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { + auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to split graph for Flash Attention check"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; + bool fa_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_FLASH_ATTN_EXT) { + continue; + } + ggml_backend_dev_t device_fa = ggml_backend_get_device( + ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_fa != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); + // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways + fa_device_mismatch = true; + break; + } + } + if (fa_device_mismatch) { + cparams.flash_attn = false; + LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); + if (ggml_is_quantized(params.type_v)) { + throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); + } + } else { + cparams.flash_attn = true; + LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); + } + } + + // reserve worst-case graph int n_splits_pp = -1; int n_nodes_pp = -1; int n_splits_tg = -1; int n_nodes_tg = -1; - // simulate full KV cache - - const auto mctx = memory->init_full(); - if (!mctx) { - throw std::runtime_error("failed to initialize KV cache"); - } - - cross.v_embd.clear(); - // reserve pp (prompt processing) graph first so that buffers are only allocated once { auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); @@ -444,26 +476,12 @@ llama_memory_t llama_context::get_memory() const { return memory.get(); } -// deprecated -void llama_context::kv_self_defrag_sched() { - if (!memory) { - return; - } - - memory_force_optimize = true; -} - -// deprecated -bool llama_context::kv_self_update(bool optimize) { +bool llama_context::memory_update(bool optimize) { if (!memory) { return false; } { - // TODO: remove in the future - optimize |= memory_force_optimize; - memory_force_optimize = false; - const auto mctx = memory->init_update(this, optimize); switch (mctx->get_status()) { case LLAMA_MEMORY_STATUS_SUCCESS: @@ -908,12 +926,6 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - if (!supports_set_rows) { - // Reset state for the next token before backend sync, to allow the CPU activities in the reset to - // overlap with device computation. - ggml_backend_sched_reset(sched.get()); - } - // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { //cross.t_embd = t_embd; @@ -997,8 +1009,8 @@ int llama_context::decode(const llama_batch & batch_inp) { bool did_optimize = false; - // handle any pending defrags/shifts - kv_self_update(false); + // handle any pending shifts/copies + memory_update(false); llama_memory_context_ptr mctx; @@ -1023,7 +1035,7 @@ int llama_context::decode(const llama_batch & batch_inp) { if (!did_optimize) { did_optimize = true; - if (kv_self_update(true)) { + if (memory_update(true)) { LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens()); continue; @@ -1076,7 +1088,7 @@ int llama_context::decode(const llama_batch & batch_inp) { const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); if (!res) { - // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache + // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module llama_pos pos_min[LLAMA_MAX_SEQ]; for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { pos_min[s] = std::numeric_limits::max(); @@ -1093,7 +1105,7 @@ int llama_context::decode(const llama_batch & batch_inp) { continue; } - LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); + LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); memory->seq_rm(s, pos_min[s], -1); } @@ -1244,12 +1256,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); - if (!supports_set_rows) { - // Reset state for the next token before backend sync, to allow the CPU activities in the reset to - // overlap with device computation. - ggml_backend_sched_reset(sched.get()); - } - return 0; } @@ -1363,8 +1369,9 @@ llm_graph_result * llama_context::get_gf_res_reserve() const { return static_cast(gf_res_reserve.get()); } -ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) { +ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) { LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); + GGML_ASSERT(n_outputs >= 1); if (n_tokens % n_seqs != 0) { n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs @@ -1398,7 +1405,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u this->n_outputs = save_n_outputs; // initialize scheduler with the specified graph - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + if (split_only) { + ggml_backend_sched_split_graph(sched.get(), gf); + } else if (!ggml_backend_sched_reserve(sched.get(), gf)) { LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); return nullptr; } @@ -1438,7 +1447,9 @@ ggml_status llama_context::graph_compute( if (backend_cpu != nullptr) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu)); auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool"); - set_threadpool_fn(backend_cpu, tp); + if (set_threadpool_fn) { + set_threadpool_fn(backend_cpu, tp); + } } // set the number of threads for all the backends @@ -1877,7 +1888,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } if (memory != nullptr) { - LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); + LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); memory->state_write(io); } @@ -1963,7 +1974,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { } if (memory) { - LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); + LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); memory->state_read(io); } @@ -2248,12 +2259,13 @@ llama_context_params llama_context_default_params() { /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, + /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO, /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, /*.yarn_ext_factor =*/ -1.0f, - /*.yarn_attn_factor =*/ 1.0f, - /*.yarn_beta_fast =*/ 32.0f, - /*.yarn_beta_slow =*/ 1.0f, + /*.yarn_attn_factor =*/ -1.0f, + /*.yarn_beta_fast =*/ -1.0f, + /*.yarn_beta_slow =*/ -1.0f, /*.yarn_orig_ctx =*/ 0, /*.defrag_thold =*/ -1.0f, /*.cb_eval =*/ nullptr, @@ -2264,7 +2276,6 @@ llama_context_params llama_context_default_params() { /*.abort_callback_data =*/ nullptr, /*.embeddings =*/ false, /*.offload_kqv =*/ true, - /*.flash_attn =*/ false, /*.no_perf =*/ true, /*.op_offload =*/ true, /*.swa_full =*/ true, @@ -2292,12 +2303,30 @@ llama_context * llama_init_from_model( return nullptr; } - if (params.flash_attn && model->arch == LLM_ARCH_GROK) { + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) { LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); - params.flash_attn = false; + params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; + } + + if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { + const uint32_t blck_size = ggml_blck_size(params.type_k); + if (model->hparams.n_embd_head_k % blck_size != 0) { + LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", + __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k); + return nullptr; + } + } + + if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { + const uint32_t blck_size = ggml_blck_size(params.type_v); + if (model->hparams.n_embd_head_v % blck_size != 0) { + LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n", + __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v); + return nullptr; + } } - if (ggml_is_quantized(params.type_v) && !params.flash_attn) { + if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) { LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); return nullptr; } @@ -2343,16 +2372,6 @@ const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } -// deprecated -llama_kv_cache * llama_get_kv_self(llama_context * ctx) { - return dynamic_cast(ctx->get_memory()); -} - -// deprecated -void llama_kv_self_update(llama_context * ctx) { - ctx->kv_self_update(false); -} - enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { return ctx->pooling_type(); } @@ -2570,168 +2589,6 @@ bool llama_memory_can_shift(llama_memory_t mem) { return mem->get_can_shift(); } -// -// kv cache -// - -// deprecated -int32_t llama_kv_self_n_tokens(const llama_context * ctx) { - const auto * kv = llama_get_memory(ctx); - if (!kv) { - return 0; - } - - int32_t res = 0; - - for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) { - const llama_pos p0 = kv->seq_pos_min(s); - const llama_pos p1 = kv->seq_pos_max(s); - - if (p0 >= 0) { - res += (p1 - p0) + 1; - } - } - - return res; -} - -// deprecated -// note: this is the same as above - will be removed anyway, so it's ok -int32_t llama_kv_self_used_cells(const llama_context * ctx) { - const auto * kv = llama_get_memory(ctx); - if (!kv) { - return 0; - } - - int32_t res = 0; - - for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) { - const llama_pos p0 = kv->seq_pos_min(s); - const llama_pos p1 = kv->seq_pos_max(s); - - if (p0 >= 0) { - res += (p1 - p0) + 1; - } - } - - return res; -} - -// deprecated -void llama_kv_self_clear(llama_context * ctx) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return; - } - - llama_memory_clear(kv, true); -} - -// deprecated -bool llama_kv_self_seq_rm( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return true; - } - - return llama_memory_seq_rm(kv, seq_id, p0, p1); -} - -// deprecated -void llama_kv_self_seq_cp( - llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return; - } - - llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1); -} - -// deprecated -void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return; - } - - llama_memory_seq_keep(kv, seq_id); -} - -// deprecated -void llama_kv_self_seq_add( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return; - } - - llama_memory_seq_add(kv, seq_id, p0, p1, delta); -} - -// deprecated -void llama_kv_self_seq_div( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return; - } - - llama_memory_seq_div(kv, seq_id, p0, p1, d); -} - -// deprecated -llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return -1; - } - - return llama_memory_seq_pos_min(kv, seq_id); -} - -// deprecated -llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return -1; - } - - return llama_memory_seq_pos_max(kv, seq_id); -} - -// deprecated -void llama_kv_self_defrag(llama_context * ctx) { - // force defrag - ctx->kv_self_defrag_sched(); -} - -// deprecated -bool llama_kv_self_can_shift(const llama_context * ctx) { - auto * kv = llama_get_memory(ctx); - if (!kv) { - return false; - } - - return llama_memory_can_shift(kv); -} - // llama state API // deprecated diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 230ef8962b8..f23aa8ee136 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -46,10 +46,8 @@ struct llama_context { llama_memory_t get_memory() const; - // return true of the KV cache was updated - // TODO: remove - bool kv_self_update(bool optimize); - void kv_self_defrag_sched(); + // return true if the memory was updated + bool memory_update(bool optimize); enum llama_pooling_type pooling_type() const; @@ -198,7 +196,7 @@ struct llama_context { ggml_status graph_compute(ggml_cgraph * gf, bool batched); // reserve a graph with a dummy ubatch of the specified size - ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); + ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false); private: llm_graph_params graph_params( @@ -230,9 +228,6 @@ struct llama_context { std::unique_ptr memory; - // TODO: temporary, until the llama_kv_self_defrag() API is removed - bool memory_force_optimize = false; - // decode output (2-dimensional array: [n_outputs][n_vocab]) size_t logits_size = 0; // capacity (of floats) for logits float * logits = nullptr; @@ -288,10 +283,6 @@ struct llama_context { bool has_evaluated_once = false; - // env: LLAMA_SET_ROWS (temporary) - // ref: https://github.com/ggml-org/llama.cpp/pull/14285 - bool supports_set_rows = true; - // env: LLAMA_GRAPH_REUSE_DISABLE bool graph_reuse_disable = false; diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 38750affc50..eae7b839f48 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -4,7 +4,7 @@ #include -#define LLAMA_MAX_SEQ 64 +#define LLAMA_MAX_SEQ 256 struct llama_cparams { uint32_t n_ctx; // context size used during inference @@ -24,7 +24,6 @@ struct llama_cparams { float yarn_attn_factor; float yarn_beta_fast; float yarn_beta_slow; - float defrag_thold; bool embeddings; bool causal_attn; diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 053c72d6dc8..9f2e417f1ff 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -4,8 +4,8 @@ #include "llama-batch.h" #include "llama-cparams.h" -#include "llama-kv-cache-unified.h" -#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache.h" +#include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" #include "llama-memory-recurrent.h" @@ -258,6 +258,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { } } +static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { + LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__); + const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" : + (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" : + (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" : + (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown"; + LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); + LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__); + LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__); + + LLAMA_LOG_DEBUG(" "); + for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) { + LLAMA_LOG_DEBUG("%2d", j); + } + LLAMA_LOG_DEBUG("\n"); + + for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) { + LLAMA_LOG_DEBUG(" %2d ", i); + for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) { + float val = data[i * n_kv + j]; + if (val == -INFINITY) { + LLAMA_LOG_DEBUG(" ∞"); + } else { + LLAMA_LOG_DEBUG(" 0"); + } + } + LLAMA_LOG_DEBUG("\n"); + } +} + void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_kv = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens; @@ -267,6 +297,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { float * data = (float *) kq_mask->data; + // [TAG_NO_CACHE_ISWA] + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement"); + for (int h = 0; h < 1; ++h) { for (int i1 = 0; i1 < n_tokens; ++i1) { const llama_seq_id s1 = ubatch->seq_id[i1][0]; @@ -277,32 +310,44 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) { const llama_seq_id s0 = ubatch->seq_id[i0][0]; + if (s0 != s1) { + continue; // skip different sequences + } + + if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) { + continue; // skip future tokens for causal attention + } + + // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA] + //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) { + // continue; // skip masked tokens for SWA + //} + // TODO: reimplement this like in llama_kv_cache_unified - if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) { - if (hparams.use_alibi) { - f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]); - } else { - f = 0.0f; - } - break; + if (hparams.use_alibi) { + f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]); + } else { + f = 0.0f; } } - data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f; } } } + if (debug) { + print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + } } -void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { +void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) { mctx->set_input_k_idxs(self_k_idxs, ubatch); mctx->set_input_v_idxs(self_v_idxs, ubatch); mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } -bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) { - const auto * mctx = static_cast(params.mctx); +bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); this->mctx = mctx; @@ -314,12 +359,10 @@ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) res &= self_kq_mask->ne[0] == mctx->get_n_kv(); res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); - res &= mctx->get_supports_set_rows(); // TODO: tmp - return res; } -void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { +void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); @@ -331,8 +374,8 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } -bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) { - const auto * mctx = static_cast(params.mctx); +bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); this->mctx = mctx; @@ -350,8 +393,6 @@ bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & pa res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv(); res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); - res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp - return res; } @@ -1186,7 +1227,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const { } ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const { - const auto * mctx_cur = static_cast(mctx); + const auto * mctx_cur = static_cast(mctx); auto inp = std::make_unique(hparams, mctx_cur); @@ -1223,15 +1264,16 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * v, ggml_tensor * kq_b, ggml_tensor * kq_mask, - ggml_tensor * v_mla, ggml_tensor * sinks, - float kq_scale) const { + ggml_tensor * v_mla, + float kq_scale, + int il) const { const bool v_trans = v->nb[1] > v->nb[2]; // split the batch into streams if needed const auto n_stream = k->ne[3]; - q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream); + q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0); q = ggml_permute(ctx0, q, 0, 2, 1, 3); k = ggml_permute(ctx0, k, 0, 2, 1, 3); @@ -1260,6 +1302,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); + cb(cur, LLAMA_TENSOR_NAME_FATTN, il); ggml_flash_attn_ext_add_sinks(cur, sinks); ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32); @@ -1275,6 +1318,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( // The permutations are noops and only change how the tensor data is interpreted. cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); cur = ggml_mul_mat(ctx0, v_mla, cur); + cb(cur, "fattn_mla", il); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs. #endif @@ -1283,6 +1327,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); } else { ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + cb(kq, "kq", il); // note: this op tends to require high floating point range // while for some models F16 is enough, for others it is not, so we default to F32 here @@ -1290,38 +1335,48 @@ ggml_tensor * llm_graph_context::build_attn_mha( if (arch == LLM_ARCH_GROK) { // need to do the following: - // multiply by attn_output_multiplyer of 0.08838834764831845 + // multiply by attn_output_multiplier // and then : // kq = 30 * tanh(kq / 30) // before the softmax below - kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f)); - kq = ggml_scale(ctx0, kq, 30); + kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping)); + cb(kq, "kq_tanh", il); + kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping); + cb(kq, "kq_scaled", il); } if (hparams.attn_soft_cap) { kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping); + cb(kq, "kq_scaled_1", il); kq = ggml_tanh (ctx0, kq); + cb(kq, "kq_tanh", il); kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping); + cb(kq, "kq_scaled_2", il); } if (kq_b) { kq = ggml_add(ctx0, kq, kq_b); + cb(kq, "kq_plus_kq_b", il); } kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); ggml_soft_max_add_sinks(kq, sinks); + cb(kq, "kq_soft_max", il); if (!v_trans) { // note: avoid this branch v = ggml_cont(ctx0, ggml_transpose(ctx0, v)); + cb(v, "v_cont", il); } ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + cb(kqv, "kqv", il); // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA if (v_mla) { kqv = ggml_mul_mat(ctx0, v_mla, kqv); + cb(kqv, "kqv_mla", il); } cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); @@ -1360,6 +1415,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * sinks, ggml_tensor * v_mla, float kq_scale, int il) const { @@ -1375,13 +1431,14 @@ ggml_tensor * llm_graph_context::build_attn( // [TAG_NO_CACHE_PAD] // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams - assert(!ubatch.equal_seqs()); + // but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636 + //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq)); ggml_tensor * q = q_cur; ggml_tensor * k = k_cur; ggml_tensor * v = v_cur; - ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); if (wo) { @@ -1399,17 +1456,17 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -static std::unique_ptr build_attn_inp_kv_unified_impl( +static std::unique_ptr build_attn_inp_kv_impl( ggml_context * ctx0, const llama_ubatch & ubatch, const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified_context * mctx_cur) { + const llama_kv_cache_context * mctx_cur) { - auto inp = std::make_unique(hparams, cparams, mctx_cur); + auto inp = std::make_unique(hparams, cparams, mctx_cur); { - GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); const auto n_kv = mctx_cur->get_n_kv(); const auto n_tokens = ubatch.n_tokens; @@ -1427,22 +1484,23 @@ static std::unique_ptr build_attn_inp_kv_unifie return inp; } -llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { - const auto * mctx_cur = static_cast(mctx); +llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const { + const auto * mctx_cur = static_cast(mctx); - auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur); + auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur); - return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp)); + return (llm_graph_input_attn_kv *) res->add_input(std::move(inp)); } ggml_tensor * llm_graph_context::build_attn( - llm_graph_input_attn_kv_unified * inp, + llm_graph_input_attn_kv * inp, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * sinks, ggml_tensor * v_mla, float kq_scale, int il) const { @@ -1469,7 +1527,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = mctx_cur->get_k(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il); - ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); if (wo) { @@ -1488,40 +1546,15 @@ ggml_tensor * llm_graph_context::build_attn( } ggml_tensor * llm_graph_context::build_attn( - llm_graph_input_attn_kv_unified_iswa * inp, + llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, - ggml_tensor * v_mla, - float kq_scale, - int il) const { - return build_attn_with_sinks( - inp, - wo, - wo_b, - q_cur, - k_cur, - v_cur, - kq_b, - v_mla, - nullptr, - kq_scale, - il); -} - -ggml_tensor * llm_graph_context::build_attn_with_sinks( - llm_graph_input_attn_kv_unified_iswa * inp, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, - ggml_tensor * k_cur, - ggml_tensor * v_cur, - ggml_tensor * kq_b, - ggml_tensor * v_mla, ggml_tensor * sinks, + ggml_tensor * v_mla, float kq_scale, int il) const { // these nodes are added to the graph together so that they are not reordered @@ -1561,7 +1594,7 @@ ggml_tensor * llm_graph_context::build_attn_with_sinks( ggml_tensor * k = mctx_cur->get_k(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il); - ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); if (wo) { @@ -1600,6 +1633,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * sinks, ggml_tensor * v_mla, float kq_scale, int il) const { @@ -1615,7 +1649,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = k_cur; ggml_tensor * v = v_cur; - ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); if (wo) { @@ -1636,10 +1670,10 @@ ggml_tensor * llm_graph_context::build_attn( // TODO: maybe separate the inner implementation into a separate function // like with the non-sliding window equivalent // once sliding-window hybrid caches are a thing. -llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { - const auto * mctx_cur = static_cast(mctx); +llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const { + const auto * mctx_cur = static_cast(mctx); - auto inp = std::make_unique(hparams, cparams, mctx_cur); + auto inp = std::make_unique(hparams, cparams, mctx_cur); const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; @@ -1656,7 +1690,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif } { - GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA"); + GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA"); const auto n_kv = mctx_cur->get_swa()->get_n_kv(); @@ -1669,7 +1703,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; } - return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp)); + return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); } ggml_tensor * llm_graph_context::build_rs( @@ -1792,7 +1826,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { const auto * mctx_cur = static_cast(mctx); auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); - auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); + auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); auto inp = std::make_unique(std::move(inp_attn), std::move(inp_rs), mctx_cur); diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 6ff49de3a1c..ca90fdf613f 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -19,8 +19,8 @@ struct llama_cparams; struct llama_memory_context_i; -class llama_kv_cache_unified_context; -class llama_kv_cache_unified_iswa_context; +class llama_kv_cache_context; +class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; @@ -78,6 +78,11 @@ struct llm_graph_params; class llm_graph_input_i { public: + llm_graph_input_i() { + const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG"); + debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0; + } + virtual ~llm_graph_input_i() = default; virtual void set_input(const llama_ubatch * ubatch) = 0; @@ -90,6 +95,9 @@ class llm_graph_input_i { GGML_UNUSED(params); return false; } +protected: + // env: LLAMA_GRAPH_INPUT_DEBUG + int debug = 0; }; using llm_graph_input_ptr = std::unique_ptr; @@ -152,7 +160,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { public: llm_graph_input_pos_bucket_kv( const llama_hparams & hparams, - const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {} + const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {} virtual ~llm_graph_input_pos_bucket_kv() = default; void set_input(const llama_ubatch * ubatch) override; @@ -161,7 +169,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { const llama_hparams hparams; - const llama_kv_cache_unified_context * mctx; + const llama_kv_cache_context * mctx; }; class llm_graph_input_out_ids : public llm_graph_input_i { @@ -257,17 +265,17 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i { const llama_cparams cparams; }; -class llm_graph_input_attn_kv_unified : public llm_graph_input_i { +class llm_graph_input_attn_kv : public llm_graph_input_i { public: - llm_graph_input_attn_kv_unified( + llm_graph_input_attn_kv( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified_context * mctx) : + const llama_kv_cache_context * mctx) : hparams(hparams), cparams(cparams), mctx(mctx) { } - ~llm_graph_input_attn_kv_unified() = default; + ~llm_graph_input_attn_kv() = default; void set_input(const llama_ubatch * ubatch) override; @@ -290,20 +298,20 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { const llama_hparams hparams; const llama_cparams cparams; - const llama_kv_cache_unified_context * mctx; + const llama_kv_cache_context * mctx; }; -class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { +class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { public: - llm_graph_input_attn_kv_unified_iswa( + llm_graph_input_attn_kv_iswa( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified_iswa_context * mctx) : + const llama_kv_cache_iswa_context * mctx) : hparams(hparams), cparams(cparams), mctx(mctx) { } - ~llm_graph_input_attn_kv_unified_iswa() = default; + ~llm_graph_input_attn_kv_iswa() = default; void set_input(const llama_ubatch * ubatch) override; @@ -330,7 +338,7 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { const llama_hparams hparams; const llama_cparams cparams; - const llama_kv_cache_unified_iswa_context * mctx; + const llama_kv_cache_iswa_context * mctx; }; class llm_graph_input_attn_cross : public llm_graph_input_i { @@ -351,7 +359,7 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { class llm_graph_input_mem_hybrid : public llm_graph_input_i { public: llm_graph_input_mem_hybrid( - std::unique_ptr inp_attn, + std::unique_ptr inp_attn, std::unique_ptr inp_rs, const llama_memory_hybrid_context * mctx) : inp_attn(std::move(inp_attn)), @@ -361,11 +369,11 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; - std::unique_ptr inp_attn; - std::unique_ptr inp_rs; + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; - llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); } - llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } const llama_memory_hybrid_context * mctx; }; @@ -680,14 +688,15 @@ struct llm_graph_context { // ggml_tensor * build_attn_mha( - ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens] - ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens] - ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false) - ggml_tensor * kq_b, - ggml_tensor * kq_mask, - ggml_tensor * sinks, - ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] - float kq_scale) const; + ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false) + ggml_tensor * kq_b, + ggml_tensor * kq_mask, + ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const; @@ -699,50 +708,39 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; - llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const; + llm_graph_input_attn_kv * build_attn_inp_kv() const; ggml_tensor * build_attn( - llm_graph_input_attn_kv_unified * inp, + llm_graph_input_attn_kv * inp, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; - llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const; + llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const; // note: if k_cur or v_cur are not provided, they will not be stored in the memory ggml_tensor * build_attn( - llm_graph_input_attn_kv_unified_iswa * inp, + llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional ggml_tensor * kq_b, - ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] - float kq_scale, - int il) const; - - // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this - ggml_tensor * build_attn_with_sinks( - llm_graph_input_attn_kv_unified_iswa * inp, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] - ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional - ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional - ggml_tensor * kq_b, - ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -756,6 +754,7 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -765,7 +764,7 @@ struct llm_graph_context { // // TODO: move this implementation to llama_memory_recurrent. - // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v + // this is analogous to llama_kv_cache::cpy_k / cpy_v // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in // `llama_memory_recurrent` diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 7a06368dcda..c04ac58f1af 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -1,6 +1,7 @@ #include "llama-hparams.h" #include "ggml.h" +#include void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) { if (dense_first) { @@ -153,3 +154,64 @@ bool llama_hparams::is_swa(uint32_t il) const { GGML_ABORT("fatal error"); } + +bool llama_hparams::has_kv(uint32_t il) const { + if (n_layer_kv_from_start >= 0) { + if (il < (uint32_t) n_layer_kv_from_start) { + return true; + } + + return false; + } + + // by default, all layers have kv + return true; +} + +uint32_t llama_hparams::n_layer_kv() const { + uint32_t res = 0; + + for (uint32_t il = 0; il < n_layer; ++il) { + if (has_kv(il)) { + res++; + } + } + + return res; +} + +bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) { + assert(p0 >= 0 && p1 >= 0); + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; + case LLAMA_SWA_TYPE_SYMMETRIC: + { + const int32_t half_n_swa = (int32_t) n_swa / 2; + const int32_t pos_diff = p1 - p0; + + // Mask if outside the symmetric window + if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { + return true; + } + } break; + } + + return false; +} diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index bd231224432..202cbbd1b28 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -16,9 +16,10 @@ enum llama_expert_gating_func_type { }; enum llama_swa_type { - LLAMA_SWA_TYPE_NONE = 0, - LLAMA_SWA_TYPE_STANDARD = 1, - LLAMA_SWA_TYPE_CHUNKED = 2, + LLAMA_SWA_TYPE_NONE = 0, + LLAMA_SWA_TYPE_STANDARD = 1, + LLAMA_SWA_TYPE_CHUNKED = 2, + LLAMA_SWA_TYPE_SYMMETRIC = 3, }; struct llama_hparams_posnet { @@ -41,6 +42,7 @@ struct llama_hparams { uint32_t n_embd; uint32_t n_embd_features = 0; uint32_t n_layer; + int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache uint32_t n_rot; uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head @@ -80,8 +82,9 @@ struct llama_hparams { float f_norm_rms_eps; float f_norm_group_eps; - float f_attn_logit_softcapping = 50.0f; - float f_final_logit_softcapping = 30.0f; + float f_attn_logit_softcapping = 50.0f; + float f_router_logit_softcapping = 30.0f; + float f_final_logit_softcapping = 30.0f; // for RWKV uint32_t rescale_every_n_layers = 0; @@ -102,6 +105,11 @@ struct llama_hparams { uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul = 0.0f; + float yarn_ext_factor = -1.0f; + float yarn_attn_factor = 1.0f; + float yarn_beta_fast = 32.0f; + float yarn_beta_slow = 1.0f; + std::array rope_sections; // Sliding Window Attention (SWA) @@ -134,10 +142,14 @@ struct llama_hparams { float f_embedding_scale = 0.0f; float f_attention_scale = 0.0f; + // grok-2 + float f_attn_out_scale = 0.0f; + uint32_t attn_temp_length = 0; + bool causal_attn = true; bool use_alibi = false; bool attn_soft_cap = false; - bool use_kq_norm = true; + bool use_kq_norm = false; // for Classifiers uint32_t n_cls_out = 1; @@ -157,6 +169,7 @@ struct llama_hparams { // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; + uint32_t dec_n_layer = 0; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -221,6 +234,16 @@ struct llama_hparams { uint32_t n_pos_per_embd() const; bool is_swa(uint32_t il) const; + + bool has_kv(uint32_t il) const; + + // number of layers for which has_kv() returns true + uint32_t n_layer_kv() const; + + // note that this function uses different SWA parameters from those in the hparams + // TODO: think of a better place for this function + // TODO: pack the SWA params in a struct? + static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1); }; static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); diff --git a/examples/talk-llama/llama-impl.h b/examples/talk-llama/llama-impl.h index 02b1d07f840..c5163e9225a 100644 --- a/examples/talk-llama/llama-impl.h +++ b/examples/talk-llama/llama-impl.h @@ -59,3 +59,5 @@ std::string llama_format_tensor_shape(const std::vector & ne); std::string llama_format_tensor_shape(const struct ggml_tensor * t); std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i); + +#define LLAMA_TENSOR_NAME_FATTN "__fattn__" diff --git a/examples/talk-llama/llama-kv-cache-unified-iswa.cpp b/examples/talk-llama/llama-kv-cache-iswa.cpp similarity index 59% rename from examples/talk-llama/llama-kv-cache-unified-iswa.cpp rename to examples/talk-llama/llama-kv-cache-iswa.cpp index 1e363fff2a5..d7342914c6b 100644 --- a/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +++ b/examples/talk-llama/llama-kv-cache-iswa.cpp @@ -1,4 +1,4 @@ -#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache-iswa.h" #include "llama-impl.h" #include "llama-batch.h" @@ -8,10 +8,10 @@ #include // -// llama_kv_cache_unified_iswa +// llama_kv_cache_iswa // -llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( +llama_kv_cache_iswa::llama_kv_cache_iswa( const llama_model & model, ggml_type type_k, ggml_type type_v, @@ -22,9 +22,26 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( uint32_t kv_size, uint32_t n_seq_max, uint32_t n_ubatch, - uint32_t n_pad) : hparams(model.hparams), unified(unified) { - llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; - llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; + uint32_t n_pad, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) { + + // chain filters + const layer_filter_cb filter_base = [&](int32_t il) { + if (filter && !filter(il)) { + return false; + } + + return !model.hparams.is_swa(il); + }; + + const layer_filter_cb filter_swa = [&](int32_t il) { + if (filter && !filter(il)) { + return false; + } + + return model.hparams.is_swa(il); + }; const uint32_t size_base = kv_size; @@ -40,25 +57,25 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); - kv_base = std::make_unique( - model, std::move(filter_base), type_k, type_v, + kv_base = std::make_unique( + model, type_k, type_v, v_trans, offload, unified, size_base, n_seq_max, n_pad, - 0, LLAMA_SWA_TYPE_NONE); + 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); - kv_swa = std::make_unique( - model, std::move(filter_swa), type_k, type_v, + kv_swa = std::make_unique( + model, type_k, type_v, v_trans, offload, unified, size_swa, n_seq_max, n_pad, - hparams.n_swa, hparams.swa_type); + hparams.n_swa, hparams.swa_type, filter_swa, reuse); } -void llama_kv_cache_unified_iswa::clear(bool data) { +void llama_kv_cache_iswa::clear(bool data) { kv_base->clear(data); kv_swa ->clear(data); } -bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { +bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { bool res = true; res = res & kv_base->seq_rm(seq_id, p0, p1); @@ -67,36 +84,36 @@ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llam return res; } -void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { +void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1); kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1); } -void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { +void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) { kv_base->seq_keep(seq_id); kv_swa ->seq_keep(seq_id); } -void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { +void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { kv_base->seq_add(seq_id, p0, p1, shift); kv_swa ->seq_add(seq_id, p0, p1, shift); } -void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { +void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { kv_base->seq_div(seq_id, p0, p1, d); kv_swa ->seq_div(seq_id, p0, p1, d); } -llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const { +llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const { // the base cache is a superset of the SWA cache, so we can just check the SWA cache return kv_swa->seq_pos_min(seq_id); } -llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { +llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const { return kv_swa->seq_pos_max(seq_id); } -llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { +llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { GGML_UNUSED(embd_all); // first try simple split @@ -136,7 +153,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all assert(sinfos_base.size() == sinfos_swa.size()); - return std::make_unique( + return std::make_unique( this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); } while (false); @@ -172,29 +189,29 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all assert(sinfos_base.size() == sinfos_swa.size()); - return std::make_unique( + return std::make_unique( this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); } while (false); // TODO: if we fail again, we should attempt different splitting strategies // but to do that properly, we first have to refactor the batches to be more flexible - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } -llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() { - return std::make_unique(this); +llama_memory_context_ptr llama_kv_cache_iswa::init_full() { + return std::make_unique(this); } -llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) { - return std::make_unique(this, lctx, optimize); +llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); } -bool llama_kv_cache_unified_iswa::get_can_shift() const { +bool llama_kv_cache_iswa::get_can_shift() const { return kv_base->get_size() == kv_swa->get_size(); } -void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { +void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) { kv_base->state_write(io, seq_id, flags); } @@ -202,7 +219,7 @@ void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_i kv_swa->state_write(io, seq_id, flags); } -void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { +void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) { kv_base->state_read(io, seq_id, flags); } @@ -210,29 +227,29 @@ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id kv_swa->state_read(io, seq_id, flags); } -llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const { +llama_kv_cache * llama_kv_cache_iswa::get_base() const { return kv_base.get(); } -llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const { +llama_kv_cache * llama_kv_cache_iswa::get_swa() const { return kv_swa.get(); } // -// llama_kv_cache_unified_iswa_context +// llama_kv_cache_iswa_context // -llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {} +llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {} -llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( - llama_kv_cache_unified_iswa * kv) : +llama_kv_cache_iswa_context::llama_kv_cache_iswa_context( + llama_kv_cache_iswa * kv) : ctx_base(kv->get_base()->init_full()), ctx_swa (kv->get_swa ()->init_full()), status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { } -llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( - llama_kv_cache_unified_iswa * kv, +llama_kv_cache_iswa_context::llama_kv_cache_iswa_context( + llama_kv_cache_iswa * kv, llama_context * lctx, bool optimize) : ctx_base(kv->get_base()->init_update(lctx, optimize)), @@ -240,21 +257,21 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { } -llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( - llama_kv_cache_unified_iswa * kv, +llama_kv_cache_iswa_context::llama_kv_cache_iswa_context( + llama_kv_cache_iswa * kv, slot_info_vec_t sinfos_base, slot_info_vec_t sinfos_swa, std::vector ubatches) : ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal - ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)), - ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)), + ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)), + ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)), status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { } -llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default; +llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default; -bool llama_kv_cache_unified_iswa_context::next() { +bool llama_kv_cache_iswa_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); ctx_base->next(); @@ -267,7 +284,7 @@ bool llama_kv_cache_unified_iswa_context::next() { return true; } -bool llama_kv_cache_unified_iswa_context::apply() { +bool llama_kv_cache_iswa_context::apply() { assert(!llama_memory_status_is_fail(status)); bool res = true; @@ -278,24 +295,24 @@ bool llama_kv_cache_unified_iswa_context::apply() { return res; } -llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const { +llama_memory_status llama_kv_cache_iswa_context::get_status() const { return status; } -const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const { +const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return ubatches[i_next]; } -const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const { +const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return static_cast(ctx_base.get()); + return static_cast(ctx_base.get()); } -const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const { +const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return static_cast(ctx_swa.get()); + return static_cast(ctx_swa.get()); } diff --git a/examples/talk-llama/llama-kv-cache-unified-iswa.h b/examples/talk-llama/llama-kv-cache-iswa.h similarity index 68% rename from examples/talk-llama/llama-kv-cache-unified-iswa.h rename to examples/talk-llama/llama-kv-cache-iswa.h index 7bc4df718d3..5ed134b7958 100644 --- a/examples/talk-llama/llama-kv-cache-unified-iswa.h +++ b/examples/talk-llama/llama-kv-cache-iswa.h @@ -1,19 +1,19 @@ #pragma once -#include "llama-kv-cache-unified.h" +#include "llama-kv-cache.h" #include // -// llama_kv_cache_unified_iswa +// llama_kv_cache_iswa // -// utilizes two instances of llama_kv_cache_unified +// utilizes two instances of llama_kv_cache // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers -class llama_kv_cache_unified_iswa : public llama_memory_i { +class llama_kv_cache_iswa : public llama_memory_i { public: - llama_kv_cache_unified_iswa( + llama_kv_cache_iswa( const llama_model & model, ggml_type type_k, ggml_type type_v, @@ -24,9 +24,11 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { uint32_t kv_size, uint32_t n_seq_max, uint32_t n_ubatch, - uint32_t n_pad); + uint32_t n_pad, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse); - ~llama_kv_cache_unified_iswa() = default; + ~llama_kv_cache_iswa() = default; // // llama_memory_i @@ -60,46 +62,46 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; // - // llama_kv_cache_unified_iswa specific API + // llama_kv_cache_iswa specific API // - llama_kv_cache_unified * get_base() const; - llama_kv_cache_unified * get_swa () const; + llama_kv_cache * get_base() const; + llama_kv_cache * get_swa () const; private: const llama_hparams & hparams; const bool unified; - std::unique_ptr kv_base; - std::unique_ptr kv_swa; + std::unique_ptr kv_base; + std::unique_ptr kv_swa; }; -class llama_kv_cache_unified_iswa_context : public llama_memory_context_i { +class llama_kv_cache_iswa_context : public llama_memory_context_i { public: - using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; // used for errors - llama_kv_cache_unified_iswa_context(llama_memory_status status); + llama_kv_cache_iswa_context(llama_memory_status status); // used to create a full-cache context - llama_kv_cache_unified_iswa_context( - llama_kv_cache_unified_iswa * kv); + llama_kv_cache_iswa_context( + llama_kv_cache_iswa * kv); // used to create an update context - llama_kv_cache_unified_iswa_context( - llama_kv_cache_unified_iswa * kv, + llama_kv_cache_iswa_context( + llama_kv_cache_iswa * kv, llama_context * lctx, bool optimize); // used to create a batch processing context from a batch - llama_kv_cache_unified_iswa_context( - llama_kv_cache_unified_iswa * kv, + llama_kv_cache_iswa_context( + llama_kv_cache_iswa * kv, slot_info_vec_t sinfos_base, slot_info_vec_t sinfos_swa, std::vector ubatches); - virtual ~llama_kv_cache_unified_iswa_context(); + virtual ~llama_kv_cache_iswa_context(); // // llama_memory_context_i @@ -112,14 +114,14 @@ class llama_kv_cache_unified_iswa_context : public llama_memory_context_i { const llama_ubatch & get_ubatch() const override; // - // llama_kv_cache_unified_iswa_context specific API + // llama_kv_cache_iswa_context specific API // - const llama_kv_cache_unified_context * get_base() const; - const llama_kv_cache_unified_context * get_swa() const; + const llama_kv_cache_context * get_base() const; + const llama_kv_cache_context * get_swa() const; private: - //llama_kv_cache_unified_iswa * kv; + //llama_kv_cache_iswa * kv; // the index of the next ubatch to process size_t i_next = 0; diff --git a/examples/talk-llama/llama-kv-cache-unified.h b/examples/talk-llama/llama-kv-cache-unified.h deleted file mode 100644 index 07a7c9e4e46..00000000000 --- a/examples/talk-llama/llama-kv-cache-unified.h +++ /dev/null @@ -1,399 +0,0 @@ -#pragma once - -#include "llama-batch.h" -#include "llama-graph.h" -#include "llama-kv-cells.h" -#include "llama-memory.h" - -#include -#include - -struct llama_cparams; -struct llama_hparams; -struct llama_model; -struct llama_context; - -// -// llama_kv_cache_unified -// - -class llama_kv_cache_unified : public llama_memory_i { -public: - static uint32_t get_padding(const llama_cparams & cparams); - - // this callback is used to filter out layers that should not be included in the cache - using layer_filter_cb = std::function; - - struct defrag_info { - bool empty() const { - return ids.empty(); - } - - // contains information about which cell moves where: - // - cell i moves to ids[i] - // - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved - std::vector ids; - }; - - struct stream_copy_info { - bool empty() const { - assert(ssrc.size() == sdst.size()); - return ssrc.empty(); - } - - std::vector ssrc; - std::vector sdst; - }; - - // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the - // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]] - struct slot_info { - // data for ggml_set_rows - using idx_vec_t = std::vector; - - // number of streams: ns = s1 - s0 + 1 - llama_seq_id s0; - llama_seq_id s1; - - std::vector strm; // [ns] - std::vector idxs; // [ns] - - uint32_t head() const { - GGML_ASSERT(idxs.size() == 1); - GGML_ASSERT(!idxs[0].empty()); - - return idxs[0][0]; - } - - void resize(size_t n) { - strm.resize(n); - idxs.resize(n); - } - - size_t size() const { - GGML_ASSERT(idxs.size() == strm.size()); - GGML_ASSERT(!idxs.empty()); - - return idxs[0].size(); - } - - size_t n_stream() const { - return strm.size(); - } - - bool empty() const { - return idxs.empty(); - } - - void clear() { - idxs.clear(); - } - }; - - using slot_info_vec_t = std::vector; - - llama_kv_cache_unified( - const llama_model & model, - layer_filter_cb && filter, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - bool unified, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_pad, - uint32_t n_swa, - llama_swa_type swa_type); - - ~llama_kv_cache_unified() = default; - - // - // llama_memory_i - // - - llama_memory_context_ptr init_batch( - llama_batch_allocr & balloc, - uint32_t n_ubatch, - bool embd_all) override; - - llama_memory_context_ptr init_full() override; - - llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; - - bool get_can_shift() const override; - - void clear(bool data) override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_min(llama_seq_id seq_id) const override; - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - // state write/load - - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; - - // - // llama_kv_cache_unified specific API - // - - uint32_t get_size() const; - uint32_t get_n_stream() const; - - bool get_has_shift() const; - - // - // graph_build API - // - - uint32_t get_n_kv() const; - - // TODO: temporary - bool get_supports_set_rows() const; - - // get views of the current state of the cache - ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; - ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; - - // store k_cur and v_cur in the cache based on the provided head location - ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const; - ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const; - - // - // preparation API - // - - // find places for the provided ubatches in the cache, returns the slot infos - // return empty vector on failure - slot_info_vec_t prepare(const std::vector & ubatches); - - bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info); - - // find a slot of kv cells that can hold the ubatch - // if cont == true, then the slot must be continuous - // return empty slot_info on failure - slot_info find_slot(const llama_ubatch & ubatch, bool cont) const; - - // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]] - void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch); - - // - // input API - // - - ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; - ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; - - void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; - void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; - - void set_input_k_shift(ggml_tensor * dst) const; - - void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; - void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; - -private: - const llama_model & model; - const llama_hparams & hparams; - - struct kv_layer { - // layer index in the model - // note: can be different from the layer index in the KV cache - uint32_t il; - - ggml_tensor * k; - ggml_tensor * v; - - std::vector k_stream; - std::vector v_stream; - }; - - bool v_trans = true; // the value tensor is transposed - - const uint32_t n_seq_max = 1; - const uint32_t n_stream = 1; - - // required padding - const uint32_t n_pad = 1; - - // SWA - const uint32_t n_swa = 0; - - // env: LLAMA_KV_CACHE_DEBUG - int debug = 0; - - // env: LLAMA_SET_ROWS (temporary) - // ref: https://github.com/ggml-org/llama.cpp/pull/14285 - bool supports_set_rows = true; - - const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; - - std::vector ctxs; - std::vector bufs; - - // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) - // note: this is not part of the KV state and it's only used to speed-up the find_slot() method - std::vector v_heads; - - std::vector v_cells; - - // maps from a sequence id to a stream id - std::vector seq_to_stream; - - // pending stream copies that will be applied during the next update - stream_copy_info sc_info; - - std::vector layers; - - // model layer id -> KV cache layer id - std::unordered_map map_layer_ids; - - // return non-empty vector if cells have been moved - defrag_info defrag_prepare(int32_t n_max_nodes) const; - - size_t total_size() const; - - size_t size_k_bytes() const; - size_t size_v_bytes() const; - - bool is_masked_swa(llama_pos p0, llama_pos p1) const; - - ggml_tensor * build_rope_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const; - - ggml_cgraph * build_graph_shift( - llm_graph_result * res, - llama_context * lctx) const; - - ggml_cgraph * build_graph_defrag( - llm_graph_result * res, - llama_context * lctx, - const defrag_info & dinfo) const; - - struct cell_ranges_t { - uint32_t strm; - - std::vector> data; // ranges, from inclusive, to exclusive - }; - - void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const; - void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const; - - bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count); -}; - -class llama_kv_cache_unified_context : public llama_memory_context_i { -public: - // some shorthands - using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; - using defrag_info = llama_kv_cache_unified::defrag_info; - using stream_copy_info = llama_kv_cache_unified::stream_copy_info; - - // used for errors - llama_kv_cache_unified_context(llama_memory_status status); - - // used to create a full-cache context - llama_kv_cache_unified_context( - llama_kv_cache_unified * kv); - - // used to create an update context - llama_kv_cache_unified_context( - llama_kv_cache_unified * kv, - llama_context * lctx, - bool do_shift, - defrag_info dinfo, - stream_copy_info sc_info); - - // used to create a batch procesing context from a batch - llama_kv_cache_unified_context( - llama_kv_cache_unified * kv, - slot_info_vec_t sinfos, - std::vector ubatches); - - virtual ~llama_kv_cache_unified_context(); - - // - // llama_memory_context_i - // - - bool next() override; - bool apply() override; - - llama_memory_status get_status() const override; - const llama_ubatch & get_ubatch() const override; - - // - // llama_kv_cache_unified_context specific API - // - - uint32_t get_n_kv() const; - - // TODO: temporary - bool get_supports_set_rows() const; - - // get views of the current state of the cache - ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; - ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; - - // store k_cur and v_cur in the cache based on the provided head location - ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const; - ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const; - - ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; - ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; - - void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; - void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; - - void set_input_k_shift (ggml_tensor * dst) const; - void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; - void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; - -private: - llama_memory_status status; - - llama_kv_cache_unified * kv; - llama_context * lctx; - - // - // update context - // - - bool do_shift = false; - - defrag_info dinfo; - - stream_copy_info sc_info; - - // - // batch processing context - // - - // the index of the cur ubatch to process - size_t i_cur = 0; - - slot_info_vec_t sinfos; - - std::vector ubatches; - - // - // data needed for building the compute graph for the current ubatch: - // - - // a heuristic, to avoid attending the full cache if it is not yet utilized - // as the cache gets filled, the benefit from this heuristic disappears - int32_t n_kv; -}; diff --git a/examples/talk-llama/llama-kv-cache-unified.cpp b/examples/talk-llama/llama-kv-cache.cpp similarity index 70% rename from examples/talk-llama/llama-kv-cache-unified.cpp rename to examples/talk-llama/llama-kv-cache.cpp index 478ebffac0f..885be072a75 100644 --- a/examples/talk-llama/llama-kv-cache-unified.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -1,4 +1,4 @@ -#include "llama-kv-cache-unified.h" +#include "llama-kv-cache.h" #include "llama-impl.h" #include "llama-io.h" @@ -13,36 +13,29 @@ #include // -// llama_kv_cache_unified +// llama_kv_cache // -llama_kv_cache_unified::llama_kv_cache_unified( - const llama_model & model, - layer_filter_cb && filter, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - bool unified, - uint32_t kv_size, - uint32_t n_seq_max, - uint32_t n_pad, - uint32_t n_swa, - llama_swa_type swa_type) : +llama_kv_cache::llama_kv_cache( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse) : model(model), hparams(model.hparams), v_trans(v_trans), n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { GGML_ASSERT(kv_size % n_pad == 0); - // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE] - auto n_layer_cache = hparams.n_layer; - if (model.arch == LLM_ARCH_GEMMA3N) { - n_layer_cache = 20; - } - if (model.arch == LLM_ARCH_GLM4_MOE) { - // GLM-4.5: Only process up to last layer, skip final NextN layer - n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers; - } + const uint32_t n_layer_kv = hparams.n_layer_kv(); // create a context for each buffer type std::map ctx_map; @@ -50,7 +43,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -97,9 +90,14 @@ llama_kv_cache_unified::llama_kv_cache_unified( __func__, hparams.n_embd_v_gqa_max()); } - for (uint32_t il = 0; il < n_layer_cache; il++) { + for (uint32_t il = 0; il < hparams.n_layer; il++) { + if (!hparams.has_kv(il)) { + LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il); + continue; + } + if (filter && !filter(il)) { - LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); + LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il); continue; } @@ -147,23 +145,27 @@ llama_kv_cache_unified::llama_kv_cache_unified( layers.push_back({ il, k, v, k_stream, v_stream, }); } - // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE] - if (model.arch == LLM_ARCH_GEMMA3N) { - LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1); + if (reuse) { + LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__); - for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) { - if (filter && !filter(il)) { - LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); + for (uint32_t il = 0; il < hparams.n_layer; il++) { + const int32_t il_reuse = reuse(il); + + if (il_reuse < 0) { + LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il); continue; } - const bool is_swa = hparams.is_swa(il); - const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1); + if (filter && !filter(il)) { + LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il); + continue; + } GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end()); + map_layer_ids[il] = map_layer_ids[il_reuse]; - LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa); + LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il)); } } @@ -195,21 +197,9 @@ llama_kv_cache_unified::llama_kv_cache_unified( const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; - - const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); - supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : supports_set_rows; - - if (!supports_set_rows) { - // ref: https://github.com/ggml-org/llama.cpp/pull/14363 - GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support"); - } - - if (!supports_set_rows) { - LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__); - } } -void llama_kv_cache_unified::clear(bool data) { +void llama_kv_cache::clear(bool data) { for (uint32_t s = 0; s < n_stream; ++s) { v_cells[s].reset(); v_heads[s] = 0; @@ -222,7 +212,7 @@ void llama_kv_cache_unified::clear(bool data) { } } -bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { +bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); if (p0 < 0) { @@ -285,7 +275,7 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos return true; } -void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { +void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size()); GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size()); @@ -368,7 +358,7 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id //} } -void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { +void llama_kv_cache::seq_keep(llama_seq_id seq_id) { GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -390,7 +380,7 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { } } -void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { +void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -434,7 +424,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po head = new_head != cells.size() ? new_head : 0; } -void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { +void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -467,7 +457,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po } } -llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { +llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const { GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); const auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -475,7 +465,7 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { return cells.seq_pos_min(seq_id); } -llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { +llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const { GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); const auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -483,7 +473,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { return cells.seq_pos_max(seq_id); } -llama_memory_context_ptr llama_kv_cache_unified::init_batch( +llama_memory_context_ptr llama_kv_cache::init_batch( llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { @@ -513,62 +503,34 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( break; } - return std::make_unique( + return std::make_unique( this, std::move(sinfos), std::move(ubatches)); } while (false); - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } -llama_memory_context_ptr llama_kv_cache_unified::init_full() { - return std::make_unique(this); +llama_memory_context_ptr llama_kv_cache::init_full() { + return std::make_unique(this); } -llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) { - bool do_shift = get_has_shift(); - - defrag_info dinfo; - - // see if we need to defrag - if (n_stream == 1) { - // note : for now do not consider defrag for n_stream > 1 - const auto & cells = v_cells[seq_to_stream[0]]; - - bool do_defrag = optimize; - - const auto thold = lctx->get_cparams().defrag_thold; +llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) { + GGML_UNUSED(optimize); - if (!do_defrag && thold > 0.0f) { - const auto n_kv = cells.used_max_p1(); - - // - do not defrag small contexts (i.e. < 2048 tokens) - // - count the padding towards the number of used tokens - const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f; - - if (fragmentation > thold) { - LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); - - do_defrag = true; - } - } - - if (do_defrag) { - dinfo = defrag_prepare(lctx->graph_max_nodes()); - } - } + bool do_shift = get_has_shift(); - return std::make_unique(this, lctx, do_shift, std::move(dinfo), std::move(sc_info)); + return std::make_unique(this, lctx, do_shift, std::move(sc_info)); } -llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector & ubatches) { - llama_kv_cache_unified::slot_info_vec_t res; +llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector & ubatches) { + llama_kv_cache::slot_info_vec_t res; struct state_t { slot_info sinfo; // slot info for the ubatch std::vector v_heads_old; // old positions of the heads, before placing the ubatch - std::vector v_cells; // copy of the old cells, before placing the ubatch + std::vector v_cells; // copy of the old cells, before placing the ubatch }; // remember the old state of the cells so we can restore it in the end @@ -577,11 +539,8 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st bool success = true; for (const auto & ubatch : ubatches) { - // non-continuous slots require support for ggml_set_rows() - const bool cont = supports_set_rows ? false : true; - // only find a suitable slot for the ubatch. don't modify the cells yet - const auto sinfo_new = find_slot(ubatch, cont); + const auto sinfo_new = find_slot(ubatch, false); if (sinfo_new.empty()) { success = false; break; @@ -629,7 +588,7 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st return res; } -bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) { +bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) { bool updated = false; auto * sched = lctx->get_sched(); @@ -699,57 +658,10 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d } } - if (!dinfo.empty()) { - LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - - // note: for now do not consider defrag for n_stream > 1 - auto & cells = v_cells[seq_to_stream[0]]; - auto & head = v_heads[seq_to_stream[0]]; - - // apply moves: - { - const auto n_kv = dinfo.ids.size(); - - for (uint32_t i = 0; i < n_kv; ++i) { - assert(dinfo.ids[i] <= n_kv); - - if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) { - continue; - } - - cells.mv(i, dinfo.ids[i]); - } - - // reset the head so we can find the first free slot during the next ubatch - head = 0; - } - - ggml_backend_sched_reset(sched); - - auto * res = lctx->get_gf_res_reserve(); - - res->reset(); - - auto * gf = build_graph_defrag(res, lctx, dinfo); - if (!ggml_backend_sched_alloc_graph(sched, gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); - return updated; - } - - res->set_inputs(nullptr); - - if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) { - LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__); - return updated; - } - - updated = true; - } - return updated; } -llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const { +llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const { if (debug > 0) { for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { @@ -844,8 +756,8 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id); } - res.s0 = std::min(res.s0, seq_to_stream[seq_id]); - res.s1 = std::max(res.s1, seq_to_stream[seq_id]); + res.s0 = std::min(res.s0, seq_to_stream[seq_id]); + res.s1 = std::max(res.s1, seq_to_stream[seq_id]); res.strm[s] = seq_to_stream[seq_id]; res.idxs[s].reserve(n_tokens); @@ -948,7 +860,7 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ return res; } -void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { +void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; @@ -1013,21 +925,21 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u } } -bool llama_kv_cache_unified::get_can_shift() const { +bool llama_kv_cache::get_can_shift() const { return true; } -uint32_t llama_kv_cache_unified::get_size() const { +uint32_t llama_kv_cache::get_size() const { const auto & cells = v_cells[seq_to_stream[0]]; return cells.size(); } -uint32_t llama_kv_cache_unified::get_n_stream() const { +uint32_t llama_kv_cache::get_n_stream() const { return n_stream; } -bool llama_kv_cache_unified::get_has_shift() const { +bool llama_kv_cache::get_has_shift() const { bool result = false; for (uint32_t s = 0; s < n_stream; ++s) { @@ -1037,11 +949,11 @@ bool llama_kv_cache_unified::get_has_shift() const { return result; } -uint32_t llama_kv_cache_unified::get_n_kv() const { +uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const { uint32_t result = 0; - for (uint32_t s = 0; s < n_stream; ++s) { - const auto & cells = v_cells[s]; + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const auto & cells = v_cells[sinfo.strm[s]]; result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result); } @@ -1049,11 +961,7 @@ uint32_t llama_kv_cache_unified::get_n_kv() const { return result; } -bool llama_kv_cache_unified::get_supports_set_rows() const { - return supports_set_rows; -} - -ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { +ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * k = layers[ikv].k; @@ -1073,7 +981,7 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0); } -ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { +ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; @@ -1090,106 +998,113 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint // note: v->nb[1] <= v->nb[2] return ggml_view_4d(ctx, v, hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns, - ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2] - ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] + ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2] + ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0); } // note: v->nb[1] > v->nb[2] return ggml_view_4d(ctx, v, n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns, - ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, kv_size), // v->nb[2] - ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] + ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, kv_size), // v->nb[2] + ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0); } -ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const { +ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const { + GGML_UNUSED(sinfo); + const int32_t ikv = map_layer_ids.at(il); - auto * k = layers[ikv].k; + ggml_tensor * k = layers[ikv].k; - const int64_t n_embd_k_gqa = k->ne[0]; - const int64_t n_tokens = k_cur->ne[2]; + const int64_t n_embd_head = k_cur->ne[0]; + const int64_t n_head = k_cur->ne[1]; + const int64_t n_tokens = k_cur->ne[2]; - k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); + const int64_t n_embd_gqa = n_embd_head*n_head; - if (k_idxs && supports_set_rows) { - if (k->ne[2] > 1) { - k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]); - } + // we can merge dims 0 and 1 + // TODO: add ggml helper function for this? + GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]); - return ggml_set_rows(ctx, k, k_cur, k_idxs); - } + k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0); - // TODO: fallback to old ggml_cpy() method for backwards compatibility - // will be removed when ggml_set_rows() is adopted by all backends + const int64_t n_stream = k->ne[2]; - GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS"); + if (n_stream > 1) { + const int64_t kv_size = get_size(); - ggml_tensor * k_view = ggml_view_1d(ctx, k, - n_tokens*n_embd_k_gqa, - ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head()); + assert(n_embd_gqa == k->ne[0]); + assert(kv_size == k->ne[1]); + + // merge the buffer across all streams because the idxs are global + k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream); + } - return ggml_cpy(ctx, k_cur, k_view); + // store the current K values into the cache + return ggml_set_rows(ctx, k, k_cur, k_idxs); } -ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const { +ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const { + GGML_UNUSED(sinfo); + const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; - const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1]; - const int64_t n_tokens = v_cur->ne[2]; - - v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); - - if (v_idxs && supports_set_rows) { - if (!v_trans) { - if (v->ne[2] > 1) { - v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]); - } + const int64_t n_embd_head = v_cur->ne[0]; + const int64_t n_head = v_cur->ne[1]; + const int64_t n_tokens = v_cur->ne[2]; - return ggml_set_rows(ctx, v, v_cur, v_idxs); - } + const int64_t n_embd_gqa = n_embd_head*n_head; - // [TAG_V_CACHE_VARIABLE] - if (n_embd_v_gqa < v->ne[0]) { - v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0); - } + // we can merge dims 0 and 1 + GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]); - // the row becomes a single element - ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]); + const int64_t n_stream = v->ne[2]; - v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]); + // take this branch when FA is enabled (the V cache is not transposed) + if (!v_trans) { + v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0); - return ggml_set_rows(ctx, v_view, v_cur, v_idxs); - } + if (n_stream > 1) { + const int64_t kv_size = get_size(); - // TODO: fallback to old ggml_cpy() method for backwards compatibility - // will be removed when ggml_set_rows() is adopted by all backends + assert(n_embd_gqa == v->ne[0]); + assert(kv_size == v->ne[1]); - GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS"); + // merge the buffer across all streams because the idxs are global + v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream); + } - ggml_tensor * v_view = nullptr; + return ggml_set_rows(ctx, v, v_cur, v_idxs); + } - if (!v_trans) { - v_view = ggml_view_1d(ctx, v, - n_tokens*n_embd_v_gqa, - ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head()); + if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) { + // we can merge dims 0, 1 and 2 + v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens); } else { - v_cur = ggml_transpose(ctx, v_cur); + // otherwise -> make a copy to get contiguous data + v_cur = ggml_cont_2d (ctx, v_cur, n_embd_gqa, n_tokens); + } - v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa, - (v->ne[1] )*ggml_element_size(v), - (sinfo.head())*ggml_element_size(v)); + // [TAG_V_CACHE_VARIABLE] + if (n_embd_gqa < v->ne[0]) { + v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0); } - return ggml_cpy(ctx, v_cur, v_view); + // in this branch the v_idxs are constructed in such a way that each row is a single head element + ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v)); + + v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur)); + + return ggml_set_rows(ctx, v_view, v_cur, v_idxs); } -ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { +ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); @@ -1199,7 +1114,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con return k_idxs; } -ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { +ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; ggml_tensor * v_idxs; @@ -1215,11 +1130,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, con return v_idxs; } -void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { - if (!supports_set_rows) { - return; - } - +void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); @@ -1235,11 +1146,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba } } -void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { - if (!supports_set_rows) { - return; - } - +void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); @@ -1272,7 +1179,7 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba } } -void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { +void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const { GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); int32_t * data = (int32_t *) dst->data; @@ -1286,7 +1193,7 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { } } -void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { +void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); @@ -1358,7 +1265,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub } } -void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { +void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { const int64_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams"); @@ -1383,7 +1290,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama } } -size_t llama_kv_cache_unified::total_size() const { +size_t llama_kv_cache::total_size() const { size_t size = 0; for (const auto & buf : bufs) { @@ -1393,7 +1300,7 @@ size_t llama_kv_cache_unified::total_size() const { return size; } -size_t llama_kv_cache_unified::size_k_bytes() const { +size_t llama_kv_cache::size_k_bytes() const { size_t size_k_bytes = 0; for (const auto & layer : layers) { @@ -1403,7 +1310,7 @@ size_t llama_kv_cache_unified::size_k_bytes() const { return size_k_bytes; } -size_t llama_kv_cache_unified::size_v_bytes() const { +size_t llama_kv_cache::size_v_bytes() const { size_t size_v_bytes = 0; for (const auto & layer : layers) { @@ -1413,7 +1320,7 @@ size_t llama_kv_cache_unified::size_v_bytes() const { return size_v_bytes; } -ggml_tensor * llama_kv_cache_unified::build_rope_shift( +ggml_tensor * llama_kv_cache::build_rope_shift( const llama_cparams & cparams, ggml_context * ctx, ggml_tensor * cur, @@ -1465,14 +1372,14 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift( class llm_graph_input_k_shift : public llm_graph_input_i { public: - llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {} virtual ~llm_graph_input_k_shift() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * k_shift; // I32 [kv_size*n_stream] - const llama_kv_cache_unified * kv_self; + const llama_kv_cache * kv_self; }; void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { @@ -1483,7 +1390,7 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { } } -ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { +ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { auto * ctx = res->get_ctx(); auto * gf = res->get_gf(); @@ -1525,310 +1432,11 @@ ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, return gf; } -ggml_cgraph * llama_kv_cache_unified::build_graph_defrag( - llm_graph_result * res, - llama_context * lctx, - const defrag_info & dinfo) const { - auto * ctx = res->get_ctx(); - auto * gf = res->get_gf(); - - GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag"); - - const auto & cells = v_cells[0]; - - const auto & ids = dinfo.ids; - - const auto & cparams = lctx->get_cparams(); - -#if 0 - // CPU defrag - // - // TODO: optimizations are possible: - // - multiple threads - // - avoid copying to the host memory when already there - // - // likely not worth the effort, as we have ggml_graph based defrag - // - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - - const uint32_t kv_size = size; - - std::vector buf_k; - std::vector buf_v; - - for (uint32_t il = 0; il < n_layer; ++il) { - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); - - const size_t v_size_el = ggml_type_size(v_l[il]->type); - const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); - - buf_k.resize(k_size); - buf_v.resize(v_size); - - ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); - - // batch move [i, i+nm) to [id, id+nm) - // note: cells can move only to a lower index - for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == n_kv) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < n_kv && ids[i + nm] == id + nm) { - nm++; - } - - // move keys - { - const int64_t os = i*k_size_row; - const int64_t od = id*k_size_row; - - memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); - } - - // move values (note: they are transposed) - { - const int64_t os = i; - const int64_t od = id; - - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); - } - } - - i += nm - 1; - } - - ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); - } -#else - for (uint32_t i = 0; i < ids.size(); ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == ids.size()) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < ids.size() && ids[i + nm] == id + nm) { - nm++; - } - - for (const auto & layer : layers) { - const uint32_t il = layer.il; - - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - - ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k, - n_embd_k_gqa, nm, - ggml_row_size(layer.k->type, n_embd_k_gqa), - ggml_row_size(layer.k->type, n_embd_k_gqa*i)); - - ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k, - n_embd_k_gqa, nm, - ggml_row_size(layer.k->type, n_embd_k_gqa), - ggml_row_size(layer.k->type, n_embd_k_gqa*id)); - - ggml_tensor * view_v_src; - ggml_tensor * view_v_dst; - - if (cparams.flash_attn) { - // NOTE: the V cache is not transposed when using flash attention - view_v_src = ggml_view_2d(ctx, layer.v, - n_embd_v_gqa, nm, - ggml_row_size(layer.v->type, n_embd_v_gqa), - ggml_row_size(layer.v->type, n_embd_v_gqa*i)); - - view_v_dst = ggml_view_2d(ctx, layer.v, - n_embd_v_gqa, nm, - ggml_row_size(layer.v->type, n_embd_v_gqa), - ggml_row_size(layer.v->type, n_embd_v_gqa*id)); - } else { - view_v_src = ggml_view_2d(ctx, layer.v, - nm, n_embd_v_gqa, - ggml_row_size(layer.v->type, cells.size()), - ggml_row_size(layer.v->type, i)); - - view_v_dst = ggml_view_2d(ctx, layer.v, - nm, n_embd_v_gqa, - ggml_row_size(layer.v->type, cells.size()), - ggml_row_size(layer.v->type, id)); - } - - ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); - ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); - } - - i += nm - 1; - } - - //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); -#endif - - return gf; -} - -llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const { - GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag"); - - const auto & cells = v_cells[0]; - - const uint32_t n_layer = layers.size(); - - const uint32_t n_kv = cells.used_max_p1(); - const uint32_t n_used = cells.get_used(); - - assert(n_used <= n_kv); - - //const int64_t t_start = ggml_time_us(); - - // number of cells moved - uint32_t n_moves = 0; - - // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) - // - source view, destination view, copy operation - // - x2 for keys and values - //const uint32_t max_moves = max_nodes()/(6*n_layer); - // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 - const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); - - // determine which KV cells to move where - defrag_info res; - auto & ids = res.ids; - - ids.resize(n_kv, n_kv); - - for (uint32_t i0 = 0; i0 < n_used; ++i0) { - if (!cells.is_empty(i0)) { - ids[i0] = i0; - - continue; - } - - // found a hole - fill it with data from the end of the cache - - uint32_t nh = 1; - - // determine the size of the hole - while (i0 + nh < n_used && cells.is_empty(i0 + nh)) { - nh++; - } - - uint32_t nf = 0; - uint32_t is = n_kv - 1; - - // starting from the end, find nh non-empty cells - for (; is > i0; --is) { - if (cells.is_empty(is) || ids[is] != n_kv) { - continue; - } - - // non-empty cell which is not yet moved - nf++; - - if (nf == nh) { - break; - } - } - - // this can only happen if `n_used` is not accurate, which would be a bug - GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh"); - - nf = 0; - - uint32_t i1 = is; - - // are we moving a continuous block of memory? - bool cont = false; - - // should we stop searching for the next move? - bool stop = false; - - // go back and move the nf cells to the hole - for (; i1 < n_kv; ++i1) { - if (cells.is_empty(i1) || ids[i1] != n_kv) { - if (n_moves == max_moves) { - stop = true; - break; - } - - cont = false; - continue; - } - - // this cell goes to (i0 + nf) - ids[i1] = i0 + nf; - - if (!cont) { - n_moves++; - cont = true; - } - - nf++; - - if (nf == nh) { - break; - } - } - - if (stop || n_moves == max_moves) { - break; - } - - //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh); - - i0 += nh - 1; - } - - if (n_moves == 0) { - return {}; - } - - LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); - - LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); - - return res; -} - -bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { - assert(p0 >= 0 && p1 >= 0); - - switch (swa_type) { - case LLAMA_SWA_TYPE_NONE: - { - } break; - case LLAMA_SWA_TYPE_STANDARD: - { - if (p1 - p0 >= (int32_t) n_swa) { - return true; - } - } break; - case LLAMA_SWA_TYPE_CHUNKED: - { - const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; - - if (p0 < pos_chunk_start) { - return true; - } - } break; - } - - return false; +bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const { + return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1); } -void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { +void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { GGML_UNUSED(flags); io.write(&n_stream, sizeof(n_stream)); @@ -1881,7 +1489,7 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq } } -void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { +void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { GGML_UNUSED(flags); GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); @@ -1917,7 +1525,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i } } -void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const { +void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const { const auto & cells = v_cells[cr.strm]; for (const auto & range : cr.data) { @@ -1945,7 +1553,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ } } -void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const { +void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const { const auto & cells = v_cells[cr.strm]; const uint32_t v_trans = this->v_trans ? 1 : 0; @@ -2040,7 +1648,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ } } -bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) { +bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) { auto & cells = v_cells[strm]; auto & head = v_heads[strm]; @@ -2137,7 +1745,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm return true; } -bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) { +bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) { auto & cells = v_cells[strm]; auto & head = v_heads[strm]; @@ -2274,13 +1882,13 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm } // -// llama_kv_cache_unified_context +// llama_kv_cache_context // -llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {} +llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {} -llama_kv_cache_unified_context::llama_kv_cache_unified_context( - llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { +llama_kv_cache_context::llama_kv_cache_context( + llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { n_kv = kv->get_size(); const uint32_t n_stream = kv->get_n_stream(); @@ -2296,26 +1904,25 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context( } } -llama_kv_cache_unified_context::llama_kv_cache_unified_context( - llama_kv_cache_unified * kv, +llama_kv_cache_context::llama_kv_cache_context( + llama_kv_cache * kv, llama_context * lctx, bool do_shift, - defrag_info dinfo, - stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) { - if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) { + stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), sc_info(std::move(sc_info)) { + if (!do_shift && this->sc_info.empty()) { status = LLAMA_MEMORY_STATUS_NO_UPDATE; } } -llama_kv_cache_unified_context::llama_kv_cache_unified_context( - llama_kv_cache_unified * kv, - llama_kv_cache_unified::slot_info_vec_t sinfos, +llama_kv_cache_context::llama_kv_cache_context( + llama_kv_cache * kv, + llama_kv_cache::slot_info_vec_t sinfos, std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) { } -llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default; +llama_kv_cache_context::~llama_kv_cache_context() = default; -bool llama_kv_cache_unified_context::next() { +bool llama_kv_cache_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); if (++i_cur >= ubatches.size()) { @@ -2325,86 +1932,81 @@ bool llama_kv_cache_unified_context::next() { return true; } -bool llama_kv_cache_unified_context::apply() { +bool llama_kv_cache_context::apply() { assert(!llama_memory_status_is_fail(status)); // no ubatches -> this is a KV cache update if (ubatches.empty()) { - kv->update(lctx, do_shift, dinfo, sc_info); + kv->update(lctx, do_shift, sc_info); return true; } kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]); - - n_kv = kv->get_n_kv(); + n_kv = kv->get_n_kv(sinfos[i_cur]); return true; } -llama_memory_status llama_kv_cache_unified_context::get_status() const { +llama_memory_status llama_kv_cache_context::get_status() const { return status; } -const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const { +const llama_ubatch & llama_kv_cache_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); return ubatches[i_cur]; } -uint32_t llama_kv_cache_unified_context::get_n_kv() const { +uint32_t llama_kv_cache_context::get_n_kv() const { return n_kv; } -bool llama_kv_cache_unified_context::get_supports_set_rows() const { - return kv->get_supports_set_rows(); -} - -ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const { +ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const { return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); } -ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const { +ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const { return kv->get_v(ctx, il, n_kv, sinfos[i_cur]); } -ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const { +ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const { return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]); } -ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const { +ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const { return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]); } -ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { +ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { return kv->build_input_k_idxs(ctx, ubatch); } -ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { +ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { return kv->build_input_v_idxs(ctx, ubatch); } -void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const { +void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const { kv->set_input_k_shift(dst); } -void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const { +void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]); } -void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const { +void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]); } -void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { +void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { kv->set_input_kq_mask(dst, ubatch, causal_attn); } -void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { +void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_pos_bucket(dst, ubatch); } -uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { +uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) { // the FA kernels require padding to avoid extra runtime boundary checks return cparams.flash_attn ? 256u : 32u; } diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index 2d04705f278..30de013f5f7 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -1,44 +1,373 @@ #pragma once -#include "llama.h" -#include "llama-io.h" +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cells.h" #include "llama-memory.h" -struct llama_kv_cache : public llama_memory_i { - virtual ~llama_kv_cache() = default; +#include +#include - // split the input batch into a set of ubatches and verify that they can fit into the cache - // return a state object containing the ubatches and KV cache state required to process them - // check the llama_memory_state_i::get_status() for the result - virtual llama_memory_state_ptr init_batch( - const llama_batch & batch, +struct llama_cparams; +struct llama_hparams; +struct llama_model; +struct llama_context; + +// +// llama_kv_cache +// + +class llama_kv_cache : public llama_memory_i { +public: + static uint32_t get_padding(const llama_cparams & cparams); + + struct stream_copy_info { + bool empty() const { + assert(ssrc.size() == sdst.size()); + return ssrc.empty(); + } + + std::vector ssrc; + std::vector sdst; + }; + + // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the + // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]] + struct slot_info { + // data for ggml_set_rows + using idx_vec_t = std::vector; + + // number of streams: ns = s1 - s0 + 1 + uint32_t s0; + uint32_t s1; + + std::vector strm; // [ns] + std::vector idxs; // [ns] + + uint32_t head() const { + GGML_ASSERT(idxs.size() == 1); + GGML_ASSERT(!idxs[0].empty()); + + return idxs[0][0]; + } + + void resize(size_t n) { + strm.resize(n); + idxs.resize(n); + } + + size_t size() const { + GGML_ASSERT(idxs.size() == strm.size()); + GGML_ASSERT(!idxs.empty()); + + return idxs[0].size(); + } + + size_t n_stream() const { + return strm.size(); + } + + bool empty() const { + return idxs.empty(); + } + + void clear() { + idxs.clear(); + } + }; + + using slot_info_vec_t = std::vector; + + llama_kv_cache( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse); + + ~llama_kv_cache() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, uint32_t n_ubatch, - bool embd_pooled, - bool logits_all) = 0; + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; - // simulate full cache, used for allocating worst-case compute buffers - virtual llama_memory_state_ptr init_full() = 0; + // state write/load - // process any pending defrag/shift/etc. operations - // optionally call once before processing a new batch - // return true if any operations were performed - virtual bool update(llama_context & lctx) = 0; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; - // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing - // TODO: change to - // llama_memory_state_ptr init_defrag(float thold) = 0; // - virtual void defrag_sched(float thold) = 0; + // llama_kv_cache specific API + // + + uint32_t get_size() const; + uint32_t get_n_stream() const; + + bool get_has_shift() const; + + // + // graph_build API + // + + uint32_t get_n_kv(const slot_info & sinfo) const; + + // get views of the current state of the cache + ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; + + // store k_cur and v_cur in the cache based on the provided head location + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const; + + // + // preparation API + // + + // find places for the provided ubatches in the cache, returns the slot infos + // return empty vector on failure + slot_info_vec_t prepare(const std::vector & ubatches); + + bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info); + + // find a slot of kv cells that can hold the ubatch + // if cont == true, then the slot must be continuous + // return empty slot_info on failure + slot_info find_slot(const llama_ubatch & ubatch, bool cont) const; + + // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]] + void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch); + + // + // input API + // + + ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; + void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; + + void set_input_k_shift(ggml_tensor * dst) const; + + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + +private: + const llama_model & model; + const llama_hparams & hparams; + + struct kv_layer { + // layer index in the model + // note: can be different from the layer index in the KV cache + uint32_t il; + + ggml_tensor * k; + ggml_tensor * v; + + std::vector k_stream; + std::vector v_stream; + }; + + bool v_trans = true; // the value tensor is transposed + + const uint32_t n_seq_max = 1; + const uint32_t n_stream = 1; + + // required padding + const uint32_t n_pad = 1; + + // SWA + const uint32_t n_swa = 0; + + // env: LLAMA_KV_CACHE_DEBUG + int debug = 0; + + // this is the SWA type of the cache - not to be confused with the model SWA type + const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; + + std::vector ctxs; + std::vector bufs; + + // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) + // note: this is not part of the KV state and it's only used to speed-up the find_slot() method + std::vector v_heads; + + std::vector v_cells; + + // maps from a sequence id to a stream id + std::vector seq_to_stream; + + // pending stream copies that will be applied during the next update + stream_copy_info sc_info; + + std::vector layers; + + // model layer id -> KV cache layer id + std::unordered_map map_layer_ids; + + size_t total_size() const; + + size_t size_k_bytes() const; + size_t size_v_bytes() const; + + bool is_masked_swa(llama_pos p0, llama_pos p1) const; + + ggml_tensor * build_rope_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const; + + ggml_cgraph * build_graph_shift( + llm_graph_result * res, + llama_context * lctx) const; + + struct cell_ranges_t { + uint32_t strm; + + std::vector> data; // ranges, from inclusive, to exclusive + }; + + void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count); +}; + +class llama_kv_cache_context : public llama_memory_context_i { +public: + // some shorthands + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + using stream_copy_info = llama_kv_cache::stream_copy_info; + + // used for errors + llama_kv_cache_context(llama_memory_status status); + + // used to create a full-cache context + llama_kv_cache_context( + llama_kv_cache * kv); + + // used to create an update context + llama_kv_cache_context( + llama_kv_cache * kv, + llama_context * lctx, + bool do_shift, + stream_copy_info sc_info); + + // used to create a batch procesing context from a batch + llama_kv_cache_context( + llama_kv_cache * kv, + slot_info_vec_t sinfos, + std::vector ubatches); + + virtual ~llama_kv_cache_context(); + + // + // llama_memory_context_i + // + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_context specific API + // + + uint32_t get_n_kv() const; + + // get views of the current state of the cache + ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; + + // store k_cur and v_cur in the cache based on the provided head location + // note: the heads in k_cur and v_cur should be layed out contiguously in memory + // - k_cur [n_embd_head_k, n_head_k, n_tokens] + // - k_idxs [n_tokens] + // - v_cur [n_embd_head_v, n_head_v, n_tokens] + // - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const; + + // create destination indices for each head of the current batch for where it would be written in the KV cache + // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but + // helps understand the implementation logic of cpy_k and cpy_v + ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; + + void set_input_k_shift (ggml_tensor * dst) const; + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + +private: + llama_memory_status status; + + llama_kv_cache * kv; + llama_context * lctx; + + // + // update context + // + + bool do_shift = false; + + stream_copy_info sc_info; + + // + // batch processing context + // + + // the index of the cur ubatch to process + size_t i_cur = 0; - // getters - virtual bool get_can_shift() const = 0; + slot_info_vec_t sinfos; - bool get_can_edit() const override { return get_can_shift(); } + std::vector ubatches; // - // state write/read + // data needed for building the compute graph for the current ubatch: // - virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; - virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; + // a heuristic, to avoid attending the full cache if it is not yet utilized + // as the cache gets filled, the benefit from this heuristic disappears + int32_t n_kv; }; diff --git a/examples/talk-llama/llama-kv-cells.h b/examples/talk-llama/llama-kv-cells.h index 0d0dd316fd0..8f6bf01456c 100644 --- a/examples/talk-llama/llama-kv-cells.h +++ b/examples/talk-llama/llama-kv-cells.h @@ -11,7 +11,7 @@ // meta information about KV cells that can be part of multiple sequences at the same time // TODO: add unit tests -class llama_kv_cells_unified { +class llama_kv_cells { public: void reset() { for (uint32_t i = 0; i < pos.size(); ++i) { @@ -77,30 +77,30 @@ class llama_kv_cells_unified { } // move cell isrc to idst (used during defrag) - void mv(uint32_t isrc, uint32_t idst) { - assert(isrc < pos.size()); - assert(idst < pos.size()); + //void mv(uint32_t isrc, uint32_t idst) { + // assert(isrc < pos.size()); + // assert(idst < pos.size()); - assert(pos[idst] == -1); - assert(pos[isrc] != -1); + // assert(pos[idst] == -1); + // assert(pos[isrc] != -1); - pos [idst] = pos [isrc]; - shift[idst] = shift[isrc]; - seq [idst] = seq [isrc]; + // pos [idst] = pos [isrc]; + // shift[idst] = shift[isrc]; + // seq [idst] = seq [isrc]; - pos [isrc] = -1; - shift[isrc] = 0; - seq [isrc].reset(); + // pos [isrc] = -1; + // shift[isrc] = 0; + // seq [isrc].reset(); - used.erase (isrc); - used.insert(idst); - } + // used.erase (isrc); + // used.insert(idst); + //} // copy the state of cells [i, i + n) (used for save/restore the state of the cells) - llama_kv_cells_unified cp(uint32_t i, uint32_t n) const { + llama_kv_cells cp(uint32_t i, uint32_t n) const { assert(i + n <= pos.size()); - llama_kv_cells_unified res; + llama_kv_cells res; res.resize(n); @@ -117,8 +117,8 @@ class llama_kv_cells_unified { } // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]) - llama_kv_cells_unified cp(const std::vector & idxs) const { - llama_kv_cells_unified res; + llama_kv_cells cp(const std::vector & idxs) const { + llama_kv_cells res; res.resize(idxs.size()); @@ -135,7 +135,7 @@ class llama_kv_cells_unified { } // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells) - void set(uint32_t i, const llama_kv_cells_unified & other) { + void set(uint32_t i, const llama_kv_cells & other) { assert(i + other.pos.size() <= pos.size()); for (uint32_t j = 0; j < other.pos.size(); ++j) { @@ -165,7 +165,7 @@ class llama_kv_cells_unified { } // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]) - void set(const std::vector & idxs, const llama_kv_cells_unified & other) { + void set(const std::vector & idxs, const llama_kv_cells & other) { assert(idxs.size() == other.pos.size()); for (uint32_t j = 0; j < other.pos.size(); ++j) { diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index cbeeb21344e..ba61ebaa885 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -9,32 +9,29 @@ // llama_memory_hybrid::llama_memory_hybrid( - const llama_model & model, - /* attn */ - ggml_type type_k, - ggml_type type_v, - bool v_trans, - uint32_t kv_size, - uint32_t n_pad, - uint32_t n_swa, - llama_swa_type swa_type, - /* recurrent */ - ggml_type type_r, - ggml_type type_s, - uint32_t rs_size, - /* common */ - uint32_t n_seq_max, - bool offload, - bool unified, - /* layer filters */ - layer_filter_cb && filter_attn, - layer_filter_cb && filter_recr) : + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + uint32_t kv_size, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn, + const layer_filter_cb & filter_recr) : hparams(model.hparams), - mem_attn(new llama_kv_cache_unified( + mem_attn(new llama_kv_cache( model, - filter_attn == nullptr ? - [&](int32_t il) { return !hparams.is_recurrent(il); } - : filter_attn, type_k, type_v, v_trans, @@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid( n_seq_max, n_pad, n_swa, - swa_type + swa_type, + filter_attn == nullptr ? + [&](int32_t il) { return !hparams.is_recurrent(il); } + : filter_attn, + nullptr )), mem_recr(new llama_memory_recurrent( model, - filter_recr == nullptr ? - [&](int32_t il) { return hparams.is_recurrent(il); } - : filter_recr, type_r, type_s, offload, rs_size, - n_seq_max + n_seq_max, + filter_recr == nullptr ? + [&](int32_t il) { return hparams.is_recurrent(il); } + : filter_recr )) {} llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { @@ -179,7 +180,7 @@ void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, mem_recr->state_read(io, seq_id); } -llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const { +llama_kv_cache * llama_memory_hybrid::get_mem_attn() const { return mem_attn.get(); } @@ -210,7 +211,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context( std::vector ubatches) : ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal - ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)), + ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)), ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } @@ -248,8 +249,8 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const { return ubatches[i_next]; } -const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const { - return static_cast(ctx_attn.get()); +const llama_kv_cache_context * llama_memory_hybrid_context::get_attn() const { + return static_cast(ctx_attn.get()); } const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const { diff --git a/examples/talk-llama/llama-memory-hybrid.h b/examples/talk-llama/llama-memory-hybrid.h index acdbc26bfb6..11a35651782 100644 --- a/examples/talk-llama/llama-memory-hybrid.h +++ b/examples/talk-llama/llama-memory-hybrid.h @@ -2,7 +2,7 @@ #include "llama-batch.h" #include "llama-graph.h" -#include "llama-kv-cache-unified.h" +#include "llama-kv-cache.h" #include "llama-memory.h" #include "llama-memory-recurrent.h" @@ -13,36 +13,32 @@ // llama_memory_hybrid // -// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to +// utilizes instances of llama_memory_recurrent and llama_kv_cache to // support models where each layer may be either attention-based or recurrent class llama_memory_hybrid : public llama_memory_i { public: - - // this callback is used to filter out layers that should not be included in the cache - using layer_filter_cb = std::function; - llama_memory_hybrid( const llama_model & model, /* attn */ - ggml_type type_k, - ggml_type type_v, - bool v_trans, - uint32_t kv_size, - uint32_t n_pad, - uint32_t n_swa, - llama_swa_type swa_type, - /* recurrent */ - ggml_type type_r, - ggml_type type_s, - uint32_t rs_size, - /* common */ - uint32_t n_seq_max, - bool offload, - bool unified, - /* layer filters */ - layer_filter_cb && filter_attn = nullptr, - layer_filter_cb && filter_recr = nullptr); + ggml_type type_k, + ggml_type type_v, + bool v_trans, + uint32_t kv_size, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn = nullptr, + const layer_filter_cb & filter_recr = nullptr); ~llama_memory_hybrid() = default; @@ -81,19 +77,19 @@ class llama_memory_hybrid : public llama_memory_i { // llama_memory_hybrid specific API // - llama_kv_cache_unified * get_mem_attn() const; + llama_kv_cache * get_mem_attn() const; llama_memory_recurrent * get_mem_recr() const; private: const llama_hparams & hparams; - const std::unique_ptr mem_attn; + const std::unique_ptr mem_attn; const std::unique_ptr mem_recr; }; class llama_memory_hybrid_context : public llama_memory_context_i { public: - using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; // init failure explicit llama_memory_hybrid_context(llama_memory_status status); @@ -125,7 +121,7 @@ class llama_memory_hybrid_context : public llama_memory_context_i { // llama_memory_hybrid_context // - const llama_kv_cache_unified_context * get_attn() const; + const llama_kv_cache_context * get_attn() const; const llama_memory_recurrent_context * get_recr() const; private: diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index 849675c4188..08716ed91ae 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -16,13 +16,13 @@ // llama_memory_recurrent::llama_memory_recurrent( - const llama_model & model, - layer_filter_cb && filter, - ggml_type type_r, - ggml_type type_s, - bool offload, - uint32_t mem_size, - uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { + const llama_model & model, + ggml_type type_r, + ggml_type type_s, + bool offload, + uint32_t mem_size, + uint32_t n_seq_max, + const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; head = 0; diff --git a/examples/talk-llama/llama-memory-recurrent.h b/examples/talk-llama/llama-memory-recurrent.h index 95c617b2c94..c4daf00495b 100644 --- a/examples/talk-llama/llama-memory-recurrent.h +++ b/examples/talk-llama/llama-memory-recurrent.h @@ -12,21 +12,17 @@ // // TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i -// see the implementation of llama_kv_cache_unified_context_i for an example how to do it +// see the implementation of llama_kv_cache_context_i for an example how to do it class llama_memory_recurrent : public llama_memory_i { public: - - // this callback is used to filter out layers that should not be included in the cache - using layer_filter_cb = std::function; - llama_memory_recurrent( - const llama_model & model, - layer_filter_cb && filter, - ggml_type type_r, - ggml_type type_s, - bool offload, - uint32_t mem_size, - uint32_t n_seq_max); + const llama_model & model, + ggml_type type_r, + ggml_type type_s, + bool offload, + uint32_t mem_size, + uint32_t n_seq_max, + const layer_filter_cb & filter); ~llama_memory_recurrent() = default; diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index 171d312cc99..ccd1f073b08 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -3,6 +3,7 @@ #include "llama.h" #include +#include struct llama_ubatch; @@ -36,8 +37,8 @@ bool llama_memory_status_is_fail(llama_memory_status status); // the interface for managing the memory context during batch processing // this interface is implemented per memory type. see: -// - llama_kv_cache_unified_context -// - llama_kv_cache_unified_iswa_context +// - llama_kv_cache_context +// - llama_kv_cache_iswa_context // ... // // the only method that should mutate the memory and the memory context is llama_memory_i::apply() @@ -64,6 +65,13 @@ using llama_memory_context_ptr = std::unique_ptr; // general concept of LLM memory // the KV cache is a type of LLM memory, but there can be other types struct llama_memory_i { + // this callback is used to filter out layers that should not be included in the cache + using layer_filter_cb = std::function; + + // this callback is used to specify which layers should reuse memory from other layers + // return negative value to indicate that the layer il should not reuse memory + using layer_reuse_cb = std::function; + virtual ~llama_memory_i() = default; // split the input batch into a set of ubatches and verify that they can fit into the cache @@ -77,7 +85,7 @@ struct llama_memory_i { // simulate full cache, used for allocating worst-case compute buffers virtual llama_memory_context_ptr init_full() = 0; - // prepare for any pending memory updates, such as shifts, defrags, etc. + // prepare for any pending memory updates, such as shifts, copies, etc. // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0; @@ -109,8 +117,3 @@ struct llama_memory_i { }; using llama_memory_ptr = std::unique_ptr; - -// TODO: temporary until the llama_kv_cache is removed from the public API -struct llama_kv_cache : public llama_memory_i { - virtual ~llama_kv_cache() = default; -}; diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index f71c40f8e3f..8182a9adf53 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -788,6 +788,7 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri } struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags) { + LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, name.c_str()); const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); if (cur == NULL) { diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 23a26f0c64e..981e57083c4 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -6,8 +6,8 @@ #include "llama-cparams.h" #include "llama-model-loader.h" -#include "llama-kv-cache-unified.h" -#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache.h" +#include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" #include "llama-memory-recurrent.h" @@ -36,6 +36,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_80M: return "80M"; case LLM_TYPE_109M: return "109M"; case LLM_TYPE_137M: return "137M"; + case LLM_TYPE_140M: return "140M"; case LLM_TYPE_160M: return "160M"; case LLM_TYPE_190M: return "190M"; case LLM_TYPE_220M: return "220M"; @@ -44,12 +45,15 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_270M: return "270M"; case LLM_TYPE_335M: return "335M"; case LLM_TYPE_350M: return "350M"; + case LLM_TYPE_360M: return "360M"; case LLM_TYPE_410M: return "410M"; case LLM_TYPE_450M: return "450M"; case LLM_TYPE_475M: return "475M"; + case LLM_TYPE_558M: return "558M"; case LLM_TYPE_700M: return "700M"; case LLM_TYPE_770M: return "770M"; case LLM_TYPE_780M: return "780M"; + case LLM_TYPE_950M: return "950M"; case LLM_TYPE_0_3B: return "0.3B"; case LLM_TYPE_0_5B: return "0.5B"; case LLM_TYPE_0_6B: return "0.6B"; @@ -83,9 +87,11 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_32B: return "32B"; case LLM_TYPE_34B: return "34B"; case LLM_TYPE_35B: return "35B"; + case LLM_TYPE_36B: return "36B"; case LLM_TYPE_40B: return "40B"; case LLM_TYPE_65B: return "65B"; case LLM_TYPE_70B: return "70B"; + case LLM_TYPE_120B: return "120B"; case LLM_TYPE_142B: return "142B"; case LLM_TYPE_236B: return "236B"; case LLM_TYPE_290B: return "290B"; @@ -619,19 +625,32 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; - hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick - hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa == 0) { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope + } else { + hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; + hparams.n_swa = 8192; + hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full + } switch (hparams.n_expert) { + case 0: { + // MobileLLM (no MoE) + switch (hparams.n_embd) { + case 2048: type = LLM_TYPE_140M; break; + case 4096: type = LLM_TYPE_360M; break; + case 6144: type = LLM_TYPE_950M; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case 16: type = LLM_TYPE_17B_16E; break; case 128: type = LLM_TYPE_17B_128E; break; default: type = LLM_TYPE_UNKNOWN; } - if (type == LLM_TYPE_17B_128E) { - hparams.use_kq_norm = false; - } + hparams.use_kq_norm = type != LLM_TYPE_17B_128E; } break; case LLM_ARCH_ARCEE: { @@ -682,7 +701,30 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GROK: { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // defaults for old GGUFs + hparams.yarn_beta_fast = 8.0f; + hparams.f_logit_scale = 0.5773502691896257f; + hparams.f_embedding_scale = 78.38367176906169f; + hparams.f_attn_out_scale = 0.08838834764831845f; + hparams.f_attn_logit_softcapping = 30.0f; + hparams.f_router_logit_softcapping = 30.0f; + // no final_logit_softcapping in grok-1 + hparams.f_final_logit_softcapping = 0.0f; + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); + ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale, false); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); switch (hparams.n_layer) { case 64: type = LLM_TYPE_314B; break; @@ -770,6 +812,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_JINA_BERT_V3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + + switch (hparams.n_layer) { + case 24: + type = LLM_TYPE_558M; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: { @@ -898,6 +952,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.causal_attn = false; } break; + case LLM_ARCH_LLADA_MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // diffusion language model uses non-causal attention + hparams.causal_attn = false; + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_A1_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_QWEN2MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); @@ -1095,7 +1161,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 18: type = LLM_TYPE_537M; break; + case 18: type = LLM_TYPE_270M; break; case 26: type = LLM_TYPE_1B; break; case 34: type = LLM_TYPE_4B; break; case 48: type = LLM_TYPE_12B; break; @@ -1113,6 +1179,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(5); + hparams.n_layer_kv_from_start = 20; hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; hparams.f_attention_scale = 1.0f; @@ -1126,6 +1193,26 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GEMMA_EMBEDDING: + { + hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; + hparams.set_swa_pattern(6); + + hparams.causal_attn = false; // embeddings do not use causal attention + hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_scale_train_swa = 1.0f; + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_0_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + + } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1279,6 +1366,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(4); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + switch (hparams.n_layer) { case 16: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_7B; break; @@ -1287,6 +1382,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_SEED_OSS: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 64: type = LLM_TYPE_36B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_OLMOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1464,12 +1567,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { // Expert gating function (GLM-4.5 uses sigmoid) ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; } // NextN/MTP parameters ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + switch (hparams.n_layer) { case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) @@ -1495,6 +1601,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.dec_start_token_id = dec_start_token_id; } + hparams.dec_n_layer = hparams.n_layer; + ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false); + switch (hparams.n_layer) { case 6: type = LLM_TYPE_60M; break; // t5-small case 8: type = LLM_TYPE_80M; break; // flan-t5-small @@ -1543,6 +1652,27 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_NEMOTRON_H: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // A layer is recurrent IFF the n_head_kv value is set to 0 and + // the n_ff value is set to 0 + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 56: type = LLM_TYPE_9B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_EXAONE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1834,7 +1964,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(2); - // TODO: switch (hparams.n_layer) + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_20B; break; + case 36: type = LLM_TYPE_120B; break; + default: type = LLM_TYPE_UNKNOWN; + } } break; case LLM_ARCH_LFM2: { @@ -2289,6 +2423,40 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; + case LLM_ARCH_LLADA_MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for llada-moe"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for llada-moe"); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } + } break; case LLM_ARCH_LLAMA4: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -2302,9 +2470,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0"); for (int i = 0; i < n_layer; ++i) { - bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0; + bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0; auto & layer = layers[i]; @@ -2465,6 +2632,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -2479,12 +2647,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + if (!layer.ffn_post_norm) { + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } } } break; case LLM_ARCH_DBRX: @@ -2613,6 +2788,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_JINA_BERT_V3: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); @@ -2648,24 +2824,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); } else { - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - - if (arch == LLM_ARCH_BERT || arch == LLM_ARCH_NOMIC_BERT_MOE) { - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_NOMIC_BERT) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); } } @@ -3433,6 +3607,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_GEMMA3: + case LLM_ARCH_GEMMA_EMBEDDING: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3962,6 +4137,43 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_SEED_OSS: + { + const uint32_t head_dim = hparams.n_embd_head_k; + const int64_t n_qo_dim = n_head * head_dim; + const int64_t n_kv_dim = n_head_kv * head_dim; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0); + + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + } + } break; + case LLM_ARCH_OLMOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4305,6 +4517,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + // n_layer: number of encoder_layers + // dec_n_layer: number of decoder_layers + const int dec_n_layer = hparams.dec_n_layer; + if (dec_n_layer > n_layer) { + layers.resize(dec_n_layer); + } + + // load encoder layers for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4320,6 +4540,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // load decoder layers + for (int i = 0; i < dec_n_layer; ++i) { + auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); @@ -4621,6 +4846,75 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_NEMOTRON_H: + { + // mamba2 Mixer SSM params + // NOTE: int64_t for tensor dimensions + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_ssm_head = hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // all blocks use the attn norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.is_recurrent(i)) { + // ssm layers + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else if (hparams.n_ff(i) == 0) { + // attention layers (with optional bias) + const int64_t n_head_i = hparams.n_head(i); + const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + } else { + // mlp layers + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_EXAONE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5469,8 +5763,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_LFM2: { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -5790,7 +6089,8 @@ void llama_model::print_info() const { arch == LLM_ARCH_JAMBA || arch == LLM_ARCH_FALCON_H1 || arch == LLM_ARCH_PLAMO2 || - arch == LLM_ARCH_GRANITE_HYBRID) { + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); @@ -5981,7 +6281,7 @@ struct llm_build_llama : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -6043,9 +6343,17 @@ struct llm_build_llama : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); + if (hparams.use_kq_norm) { + // Llama4TextL2Norm + Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); + Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + } + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -6141,7 +6449,7 @@ struct llm_build_llama_iswa : public llm_graph_context { ggml_tensor * inp_attn_scale = nullptr; inp_attn_scale = build_inp_attn_scale(); - auto * inp_attn = build_attn_inp_kv_unified_iswa(); + auto * inp_attn = build_attn_inp_kv_iswa(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -6150,7 +6458,8 @@ struct llm_build_llama_iswa : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; - const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0; + const bool use_rope = hparams.n_no_rope_layer_step > 0 && + (il + 1) % hparams.n_no_rope_layer_step != 0; // norm cur = build_norm(inpL, @@ -6219,7 +6528,7 @@ struct llm_build_llama_iswa : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -6320,7 +6629,7 @@ struct llm_build_deci : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -6396,7 +6705,7 @@ struct llm_build_deci : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -6476,7 +6785,7 @@ struct llm_build_baichuan : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr; - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -6528,7 +6837,7 @@ struct llm_build_baichuan : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -6598,7 +6907,7 @@ struct llm_build_xverse : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -6643,7 +6952,7 @@ struct llm_build_xverse : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -6712,7 +7021,7 @@ struct llm_build_falcon : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -6743,9 +7052,7 @@ struct llm_build_falcon : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -6766,7 +7073,7 @@ struct llm_build_falcon : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -6830,13 +7137,10 @@ struct llm_build_grok : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); - // multiply by embedding_multiplier_scale of 78.38367176906169 - inpL = ggml_scale(ctx0, inpL, 78.38367176906169f); - // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -6896,7 +7200,7 @@ struct llm_build_grok : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -6904,26 +7208,22 @@ struct llm_build_grok : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // Grok - // if attn_out_norm is present then apply it before adding the input - if (model.layers[il].attn_out_norm) { - cur = build_norm(cur, - model.layers[il].attn_out_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_out_norm", il); - } + cur = build_norm(cur, + model.layers[il].attn_out_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_out_norm", il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); // feed-forward network - // MoE branch cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); - cur = build_moe_ffn(cur, + // MoE branch + ggml_tensor * moe_out = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -6934,18 +7234,28 @@ struct llm_build_grok : public llm_graph_context { false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); - cb(cur, "ffn_moe_out", il); + cb(moe_out, "ffn_moe_out", il); - // Grok - // if layer_out_norm is present then apply it before adding the input - // Idea: maybe ffn_out_norm is a better name - if (model.layers[il].layer_out_norm) { - cur = build_norm(cur, - model.layers[il].layer_out_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "layer_out_norm", il); + if (model.layers[il].ffn_up) { + ggml_tensor * ffn_out = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(ffn_out, "ffn_out", il); + + cur = ggml_scale(ctx0, ggml_add(ctx0, ffn_out, moe_out), std::sqrt(2) / 2); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; } + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_post_norm", il); + cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -6968,10 +7278,14 @@ struct llm_build_grok : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); - // Grok - // multiply logits by output_multiplier_scale of 0.5773502691896257 + cur = ggml_scale(ctx0, cur, hparams.f_logit_scale); - cur = ggml_scale(ctx0, cur, 0.5773502691896257f); + // final logit soft-capping + if (hparams.f_final_logit_softcapping) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + } cb(cur, "result_output", -1); res->t_logits = cur; @@ -6996,7 +7310,7 @@ struct llm_build_dbrx : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -7023,9 +7337,7 @@ struct llm_build_dbrx : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -7045,7 +7357,7 @@ struct llm_build_dbrx : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -7120,7 +7432,7 @@ struct llm_build_starcoder : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); cb(pos, "pos_embd", -1); @@ -7145,13 +7457,9 @@ struct llm_build_starcoder : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -7159,7 +7467,7 @@ struct llm_build_starcoder : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -7225,7 +7533,7 @@ struct llm_build_refact : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -7258,7 +7566,7 @@ struct llm_build_refact : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -7367,13 +7675,17 @@ struct llm_build_bert : public llm_graph_context { cb(cur, "bqkv", il); } - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } if (model.layers[il].attn_q_norm) { @@ -7381,6 +7693,8 @@ struct llm_build_bert : public llm_graph_context { model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); } if (model.layers[il].attn_k_norm) { @@ -7388,14 +7702,12 @@ struct llm_build_bert : public llm_graph_context { model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + } // RoPE - if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { + if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) { Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -7415,7 +7727,7 @@ struct llm_build_bert : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } @@ -7454,7 +7766,7 @@ struct llm_build_bert : public llm_graph_context { 0.0f, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); - } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { + } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) { cur = build_ffn(cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, @@ -7537,9 +7849,7 @@ struct llm_build_neo_bert : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); // RoPE Qcur = ggml_rope_ext( @@ -7560,7 +7870,7 @@ struct llm_build_neo_bert : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, nullptr, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } @@ -7621,7 +7931,7 @@ struct llm_build_bloom : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); inpL = build_norm(inpL, model.tok_norm, @@ -7646,13 +7956,9 @@ struct llm_build_bloom : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -7660,7 +7966,7 @@ struct llm_build_bloom : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -7728,7 +8034,7 @@ struct llm_build_mpt : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); if (model.pos_embd) { // inp_pos - contains the positions @@ -7768,13 +8074,9 @@ struct llm_build_mpt : public llm_graph_context { cb(cur, "wqkv_clamped", il); } - ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); // Q/K Layernorm if (model.layers[il].attn_q_norm) { @@ -7782,32 +8084,23 @@ struct llm_build_mpt : public llm_graph_context { model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il); - cb(Qcur, "Qcur", il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il); - cb(Kcur, "Kcur", il); - } else { - Qcur = ggml_cont(ctx0, Qcur); - cb(Qcur, "Qcur", il); - Kcur = ggml_cont(ctx0, Kcur); - cb(Kcur, "Kcur", il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -7877,7 +8170,7 @@ struct llm_build_stablelm : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -7953,7 +8246,7 @@ struct llm_build_stablelm : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8029,7 +8322,7 @@ struct llm_build_qwen : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -8049,11 +8342,9 @@ struct llm_build_qwen : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd))); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 2*sizeof(float)*(n_embd)); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -8074,7 +8365,7 @@ struct llm_build_qwen : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8144,7 +8435,7 @@ struct llm_build_qwen2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -8194,7 +8485,7 @@ struct llm_build_qwen2 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8308,8 +8599,9 @@ struct llm_build_dream : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, - nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8408,8 +8700,9 @@ struct llm_build_llada : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, - 1.0f / sqrtf(float(n_embd_head)), il); + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8469,7 +8762,7 @@ struct llm_build_qwen2vl : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); @@ -8522,7 +8815,7 @@ struct llm_build_qwen2vl : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8590,7 +8883,7 @@ struct llm_build_qwen2moe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -8649,7 +8942,7 @@ struct llm_build_qwen2moe : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8749,7 +9042,7 @@ struct llm_build_qwen3 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -8802,7 +9095,7 @@ struct llm_build_qwen3 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -8870,7 +9163,7 @@ struct llm_build_qwen3moe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -8923,7 +9216,7 @@ struct llm_build_qwen3moe : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9000,7 +9293,7 @@ struct llm_build_phi2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -9026,21 +9319,17 @@ struct llm_build_phi2 : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9063,7 +9352,7 @@ struct llm_build_phi2 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9129,13 +9418,13 @@ struct llm_build_phi3 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - using inp_attn_type = std::conditional_t; + using inp_attn_type = std::conditional_t; inp_attn_type * inp_attn = nullptr; if constexpr (iswa) { - inp_attn = build_attn_inp_kv_unified_iswa(); + inp_attn = build_attn_inp_kv_iswa(); } else { - inp_attn = build_attn_inp_kv_unified(); + inp_attn = build_attn_inp_kv(); } ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -9164,21 +9453,17 @@ struct llm_build_phi3 : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9200,7 +9485,7 @@ struct llm_build_phi3 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9287,7 +9572,7 @@ struct llm_build_plamo : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -9334,7 +9619,7 @@ struct llm_build_plamo : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9403,7 +9688,7 @@ struct llm_build_gpt2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); cb(pos, "pos_embd", -1); @@ -9428,21 +9713,17 @@ struct llm_build_gpt2 : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9513,7 +9794,7 @@ struct llm_build_codeshell : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -9534,9 +9815,7 @@ struct llm_build_codeshell : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -9556,7 +9835,7 @@ struct llm_build_codeshell : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9626,7 +9905,7 @@ struct llm_build_orion : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -9685,7 +9964,7 @@ struct llm_build_orion : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9753,7 +10032,7 @@ struct llm_build_internlm2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -9812,7 +10091,7 @@ struct llm_build_internlm2 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -9889,7 +10168,7 @@ struct llm_build_minicpm3 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -10000,7 +10279,7 @@ struct llm_build_minicpm3 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); + q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -10084,7 +10363,7 @@ struct llm_build_gemma : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -10130,7 +10409,7 @@ struct llm_build_gemma : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -10200,7 +10479,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified_iswa(); + auto * inp_attn = build_attn_inp_kv_iswa(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -10245,7 +10524,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -10334,7 +10613,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context { ggml_tensor * inp_pos = build_inp_pos(); // TODO: is causal == true correct? might need some changes - auto * inp_attn = build_attn_inp_kv_unified_iswa(); + auto * inp_attn = build_attn_inp_kv_iswa(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -10387,7 +10666,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -10459,7 +10738,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { const int64_t n_embd_altup; const int64_t n_altup; const int i_altup_act; - const int n_layer_kv = 20; // number of layers having KV [KV_REUSE] const int n_layer_sparsity = 10; // number of layers using activation sparsity const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95) @@ -10485,7 +10763,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * inp_pos = build_inp_pos(); // TODO: is causal == true correct? might need some changes - auto * inp_attn = build_attn_inp_kv_unified_iswa(); + auto * inp_attn = build_attn_inp_kv_iswa(); // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer] ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs()); @@ -10509,8 +10787,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { // this block is made to be closely resemble Gemma3p5DecoderLayer on python code - const bool has_kv = (il < n_layer_kv); - const float freq_base_l = model.get_rope_freq_base (cparams, il); const float freq_scale_l = model.get_rope_freq_scale(cparams, il); @@ -10530,7 +10806,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens] // self-attention - if (has_kv) { + if (hparams.has_kv(il)) { // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); @@ -10568,9 +10844,9 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); } else { - // no KV layers + // reuse KV cache of earlier layers ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); @@ -10586,7 +10862,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); } cur = build_norm(cur, @@ -10864,8 +11140,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens] all_coefs = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f); // + 1.0 cb(all_coefs, "all_coefs", il); - all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup] - all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] + all_coefs = ggml_transpose(ctx0, all_coefs); // [n_tokens, n_altup] + all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1); ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup] @@ -10876,6 +11152,137 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { } }; +struct llm_build_gemma_embedding_iswa : public llm_graph_context { + llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + if (ubatch.token) { + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); + cb(inpL, "inp_scaled", -1); + } + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA] + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315 + Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); + cb(sa_out, "sa_out", il); + + cur = build_norm(sa_out, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", -1); + + cur = ggml_add(ctx0, cur, sa_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + // TODO: move up next to build_starcoder struct llm_build_starcoder2 : public llm_graph_context { llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { @@ -10892,7 +11299,7 @@ struct llm_build_starcoder2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -10951,7 +11358,7 @@ struct llm_build_starcoder2 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -11378,7 +11785,9 @@ struct llm_build_jamba : public llm_graph_context_mamba { cb(Vcur, "Vcur", il); // No RoPE :) - cur = build_attn(inp_hybrid->get_attn(), model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); + cur = build_attn(inp_hybrid->get_attn(), + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -11461,7 +11870,7 @@ struct llm_build_command_r : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -11536,7 +11945,7 @@ struct llm_build_command_r : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -11608,7 +12017,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified_iswa(); + auto * inp_attn = build_attn_inp_kv_iswa(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -11671,7 +12080,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -11743,7 +12152,7 @@ struct llm_build_olmo : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -11802,7 +12211,7 @@ struct llm_build_olmo : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, nullptr, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -11856,6 +12265,7 @@ struct llm_build_olmo : public llm_graph_context { } }; +template struct llm_build_olmo2 : public llm_graph_context { llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -11871,7 +12281,14 @@ struct llm_build_olmo2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -11904,17 +12321,36 @@ struct llm_build_olmo2 : public llm_graph_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( + const bool is_swa = hparams.is_swa(il); + + if (is_swa) { + // For sliding window layers, Olmo3 use regular rope with no yarn rope scaling. + // This is achieved here by setting freq_scale and attn_factor to 1. + // We also set ext_factor to 0 to avoid a few unnecessary computations. + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_ext( + Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -11922,7 +12358,7 @@ struct llm_build_olmo2 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -12000,7 +12436,7 @@ struct llm_build_olmoe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -12055,7 +12491,7 @@ struct llm_build_olmoe : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -12113,30 +12549,27 @@ struct llm_build_olmoe : public llm_graph_context { } }; -struct llm_build_openelm : public llm_graph_context { - llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +struct llm_build_llada_moe : public llm_graph_context { + llm_build_llada_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); ggml_tensor * cur; ggml_tensor * inpL; + inpL = build_inp_embd(model.tok_embd); // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_no_cache(); ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { - const int64_t n_head = hparams.n_head(il); - const int64_t n_head_kv = hparams.n_head_kv(il); - const int64_t n_head_qkv = 2*n_head_kv + n_head; - - cur = inpL; - ggml_tensor * residual = cur; + ggml_tensor * inpSA = inpL; // norm cur = build_norm(inpL, @@ -12144,26 +12577,155 @@ struct llm_build_openelm : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - // self-attention + // self_attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0); + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv))); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = build_norm(Qcur, - model.layers[il].attn_q_norm, NULL, - LLM_NORM_RMS, il); - cb(Qcur, "Qcur", il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_openelm : public llm_graph_context { + llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + const int64_t n_head = hparams.n_head(il); + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_head_qkv = 2*n_head_kv + n_head; + + cur = inpL; + ggml_tensor * residual = cur; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv))); + cb(Vcur, "Vcur", il); + + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, NULL, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur", il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, @@ -12188,7 +12750,7 @@ struct llm_build_openelm : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -12257,7 +12819,7 @@ struct llm_build_gptneox : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -12278,9 +12840,7 @@ struct llm_build_gptneox : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -12300,7 +12860,7 @@ struct llm_build_gptneox : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -12403,7 +12963,7 @@ struct llm_build_arctic : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -12450,7 +13010,7 @@ struct llm_build_arctic : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -12541,7 +13101,7 @@ struct llm_build_deepseek : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -12605,7 +13165,7 @@ struct llm_build_deepseek : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -12718,7 +13278,7 @@ struct llm_build_deepseek2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -12833,7 +13393,7 @@ struct llm_build_deepseek2 : public llm_graph_context { // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il); } else { ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); cb(kv, "kv", il); @@ -12867,7 +13427,7 @@ struct llm_build_deepseek2 : public llm_graph_context { // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } } @@ -12965,7 +13525,7 @@ struct llm_build_bitnet : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -13034,7 +13594,7 @@ struct llm_build_bitnet : public llm_graph_context { cur = build_attn(inp_attn, NULL, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cur = build_norm(cur, model.layers[il].attn_sub_norm, NULL, @@ -13157,7 +13717,7 @@ struct llm_build_t5_enc : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo_enc, nullptr, - Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -13229,12 +13789,14 @@ struct llm_build_t5_dec : public llm_graph_context { const int64_t n_outputs_enc = embd_enc->ne[1]; - auto * inp_attn_self = build_attn_inp_kv_unified(); + auto * inp_attn_self = build_attn_inp_kv(); auto * inp_attn_cross = build_attn_inp_cross(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + const int64_t dec_n_layer = hparams.dec_n_layer; + + for (int il = 0; il < dec_n_layer; ++il) { ggml_tensor * inpSA = inpL; // norm @@ -13263,7 +13825,7 @@ struct llm_build_t5_dec : public llm_graph_context { cur = build_attn(inp_attn_self, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -13295,7 +13857,7 @@ struct llm_build_t5_dec : public llm_graph_context { cur = build_attn(inp_attn_cross, model.layers[il].wo_cross, nullptr, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); //ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); @@ -13325,7 +13887,7 @@ struct llm_build_t5_dec : public llm_graph_context { //cb(cur, "kqv_out", il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == dec_n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); } @@ -13346,8 +13908,8 @@ struct llm_build_t5_dec : public llm_graph_context { model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, - model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, - model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, + model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate ? LLM_FFN_PAR : LLM_FFN_SEQ, il); cb(cur, "ffn_out", il); } @@ -13394,7 +13956,7 @@ struct llm_build_jais : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -13413,21 +13975,17 @@ struct llm_build_jais : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa))); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*cur->nb[0]*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/float(n_embd_head), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -13492,7 +14050,7 @@ struct llm_build_chatglm : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -13526,6 +14084,7 @@ struct llm_build_chatglm : public llm_graph_context { } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); @@ -13535,11 +14094,9 @@ struct llm_build_chatglm : public llm_graph_context { } Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -13559,7 +14116,7 @@ struct llm_build_chatglm : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -13625,7 +14182,7 @@ struct llm_build_glm4 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -13660,6 +14217,7 @@ struct llm_build_glm4 : public llm_graph_context { } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); @@ -13669,11 +14227,9 @@ struct llm_build_glm4 : public llm_graph_context { } Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); } - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -13692,7 +14248,7 @@ struct llm_build_glm4 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -13775,7 +14331,7 @@ struct llm_build_glm4_moe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -13841,7 +14397,7 @@ struct llm_build_glm4_moe : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_transformer_layers - 1 && inp_out_ids) { @@ -13935,7 +14491,7 @@ struct llm_build_nemotron : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -13995,7 +14551,7 @@ struct llm_build_nemotron : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -14049,6 +14605,138 @@ struct llm_build_nemotron : public llm_graph_context { } }; +struct llm_build_nemotron_h : public llm_graph_context_mamba { + llm_build_nemotron_h( + const llama_model & model, + const llm_graph_params & params) : + llm_graph_context_mamba(params) { + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + if (hparams.is_recurrent(il)) { + // ssm layer // + cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); + } else if (hparams.n_ff(il) == 0) { + // attention layer // + cur = build_attention_layer(cur, inp->get_attn(), model, n_embd_head, il); + } else { + cur = build_ffn_layer(cur, model, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // add residual + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "block_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + + ggml_tensor * build_attention_layer( + ggml_tensor * cur, + llm_graph_input_attn_kv * inp_attn, + const llama_model & model, + const int64_t n_embd_head, + const int il) { + + // compute Q and K and (optionally) RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + return cur; + } + + ggml_tensor * build_ffn_layer( + ggml_tensor * cur, + const llama_model & model, + const int il) { + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + return cur; + } +}; + struct llm_build_exaone : public llm_graph_context { llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -14064,7 +14752,7 @@ struct llm_build_exaone : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -14126,7 +14814,7 @@ struct llm_build_exaone : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -14196,13 +14884,13 @@ struct llm_build_exaone4 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - using inp_attn_type = std::conditional_t; + using inp_attn_type = std::conditional_t; inp_attn_type * inp_attn = nullptr; if constexpr (iswa) { - inp_attn = build_attn_inp_kv_unified_iswa(); + inp_attn = build_attn_inp_kv_iswa(); } else { - inp_attn = build_attn_inp_kv_unified(); + inp_attn = build_attn_inp_kv(); } ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -14257,7 +14945,7 @@ struct llm_build_exaone4 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } @@ -15085,7 +15773,7 @@ struct llm_build_granite : public llm_graph_context { inp_pos = build_inp_pos(); } - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -15136,12 +15824,12 @@ struct llm_build_granite : public llm_graph_context { } ggml_tensor * build_attention_layer( - ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_attn_kv_unified * inp_attn, - const llama_model & model, - const int64_t n_embd_head, - const int il) { + ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv * inp_attn, + const llama_model & model, + const int64_t n_embd_head, + const int il) { // compute Q and K and (optionally) RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -15192,7 +15880,7 @@ struct llm_build_granite : public llm_graph_context { const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; } @@ -15355,12 +16043,12 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { } ggml_tensor * build_attention_layer( - ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_attn_kv_unified * inp_attn, - const llama_model & model, - const int64_t n_embd_head, - const int il) { + ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv * inp_attn, + const llama_model & model, + const int64_t n_embd_head, + const int il) { // compute Q and K and (optionally) RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -15411,7 +16099,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; } @@ -15517,7 +16205,7 @@ struct llm_build_chameleon : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -15596,7 +16284,7 @@ struct llm_build_chameleon : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, nullptr, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -15848,7 +16536,7 @@ struct llm_build_plm : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -15952,7 +16640,7 @@ struct llm_build_plm : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); + q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -16013,7 +16701,7 @@ struct llm_build_bailingmoe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -16075,7 +16763,7 @@ struct llm_build_bailingmoe : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -16162,7 +16850,7 @@ struct llm_build_dots1 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -16215,7 +16903,7 @@ struct llm_build_dots1 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -16312,7 +17000,7 @@ struct llm_build_ernie4_5 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -16370,7 +17058,7 @@ struct llm_build_ernie4_5 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -16442,7 +17130,7 @@ struct llm_build_ernie4_5_moe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -16503,7 +17191,7 @@ struct llm_build_ernie4_5_moe : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } @@ -16656,7 +17344,7 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba { ggml_tensor * attn_out = build_attn(inp->get_attn(), model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(attn_out, "attn_out", il); cur = build_norm(inpL, @@ -16816,7 +17504,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba { private: ggml_tensor * build_plamo2_attn_layer( - llm_graph_input_attn_kv_unified * inp, + llm_graph_input_attn_kv * inp, ggml_tensor * inp_pos, ggml_tensor * cur, const llama_model & model, @@ -16838,16 +17526,14 @@ struct llm_build_plamo2 : public llm_graph_context_mamba { const int64_t k_offset = n_embd_head_q * n_head; const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv; - ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv))); + ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv, n_tokens, n_embd_head_v * sizeof(float), qkv->nb[1], v_offset * ggml_element_size(qkv)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens); - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -16866,7 +17552,9 @@ struct llm_build_plamo2 : public llm_graph_context_mamba { ext_factor, attn_factor, beta_fast, beta_slow ); - cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il); + cur = build_attn(inp, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il); } cb(cur, "attn_out", il); @@ -16913,15 +17601,13 @@ struct llm_build_plamo2 : public llm_graph_context_mamba { cb(zx, "mamba_in_proj", il); // {8192, 5, 1, 1} -> {8192, 1, 5, 1} zx = ggml_permute(ctx0, zx, 0, 2, 1, 3); - zx = ggml_cont(ctx0, zx); - zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); + zx = ggml_cont_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); cb(zx, "mamba_in_proj_out", il); // split into z and x // => {head_dim * n_heads, n_seq_tokens, n_seqs} ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx)); - x = ggml_cont(ctx0, x); - x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); + x = ggml_cont_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); // x = ggml_permute(ctx0, x, 0, 2, 1, 3); cb(x, "mamba_x_split", il); @@ -17051,7 +17737,7 @@ struct llm_build_arcee : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -17115,7 +17801,7 @@ struct llm_build_arcee : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -17186,7 +17872,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); @@ -17260,7 +17946,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -17347,7 +18033,7 @@ struct llm_build_hunyuan_dense : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); @@ -17420,7 +18106,7 @@ struct llm_build_hunyuan_dense : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -17485,7 +18171,7 @@ struct llm_build_smollm3 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -17550,7 +18236,7 @@ struct llm_build_smollm3 : public llm_graph_context { cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -17617,7 +18303,7 @@ struct llm_build_openai_moe_iswa : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified_iswa(); + auto * inp_attn = build_attn_inp_kv_iswa(); for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -17672,9 +18358,9 @@ struct llm_build_openai_moe_iswa : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn_with_sinks(inp_attn, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].attn_sinks, 1.0f/sqrtf(float(n_rot)), il); + Qcur, Kcur, Vcur, nullptr, model.layers[il].attn_sinks, nullptr, 1.0f/sqrtf(float(n_rot)), il); cb(cur, "attn_out", il); } @@ -17771,8 +18457,7 @@ struct llm_build_lfm2 : public llm_graph_context { cb(cur, "model.embedding_norm", -1); res->t_embd = cur; - // lm_head is tied with embeddings - cur = build_lora_mm(model.tok_embd, cur); + cur = build_lora_mm(model.output, cur); cb(cur, "lm_head", -1); res->t_logits = cur; @@ -17799,10 +18484,10 @@ struct llm_build_lfm2 : public llm_graph_context { return cur; } - ggml_tensor * build_attn_block(ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_attn_kv_unified * inp_attn, - int il) const { + ggml_tensor * build_attn_block(ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv * inp_attn, + int il) const { GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il)); auto const n_embd_head = hparams.n_embd_head_v; auto const n_head_kv = hparams.n_head_kv(il); @@ -17837,7 +18522,7 @@ struct llm_build_lfm2 : public llm_graph_context { ); cur = build_attn(inp_attn, model.layers[il].wo, NULL, - q, k, v, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + q, k, v, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "model.layers.{}.self_attn.out_proj", il); @@ -17914,6 +18599,137 @@ struct llm_build_lfm2 : public llm_graph_context { } }; +struct llm_build_seed_oss : public llm_graph_context { + llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + template struct llm_build_smallthinker : public llm_graph_context{ llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){ @@ -17930,13 +18746,13 @@ struct llm_build_smallthinker : public llm_graph_context{ // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - using inp_attn_type = std::conditional_t; + using inp_attn_type = std::conditional_t; inp_attn_type * inp_attn = nullptr; if constexpr (iswa) { - inp_attn = build_attn_inp_kv_unified_iswa(); + inp_attn = build_attn_inp_kv_iswa(); } else { - inp_attn = build_attn_inp_kv_unified(); + inp_attn = build_attn_inp_kv(); } ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -17981,7 +18797,7 @@ struct llm_build_smallthinker : public llm_graph_context{ cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -18043,12 +18859,15 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, // switch statement case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_JINA_BERT_V3: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: case LLM_ARCH_WAVTOKENIZER_DEC: + //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA] case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: + case LLM_ARCH_LLADA_MOE: { res = nullptr; } break; @@ -18059,14 +18878,31 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, if (llm_arch_is_recurrent(arch)) { res = new llama_memory_recurrent( *this, - nullptr, GGML_TYPE_F32, GGML_TYPE_F32, cparams.offload_kqv, std::max((uint32_t) 1, cparams.n_seq_max), - cparams.n_seq_max); + cparams.n_seq_max, + nullptr); } else if (llm_arch_is_hybrid(arch)) { - const auto padding = llama_kv_cache_unified::get_padding(cparams); + + // The main difference between hybrid architectures is the + // layer filters, so pick the right one here + llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; + llama_memory_hybrid::layer_filter_cb filter_recr = nullptr; + if (arch == LLM_ARCH_FALCON_H1) { + filter_attn = [&](int32_t) { return true; }; + filter_recr = [&](int32_t) { return true; }; + } else if (arch == LLM_ARCH_NEMOTRON_H) { + filter_attn = [&](int32_t il) { + return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + }; + filter_recr = [&](int32_t il) { + return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + }; + } + + const auto padding = llama_kv_cache::get_padding(cparams); cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); @@ -18085,10 +18921,10 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* n_seq_max */ cparams.n_seq_max, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, - /* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr, - /* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr); + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); } else { - const auto padding = llama_kv_cache_unified::get_padding(cparams); + const auto padding = llama_kv_cache::get_padding(cparams); uint32_t n_ctx_per_stream = cparams.n_ctx; @@ -18105,10 +18941,22 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + llama_memory_i::layer_reuse_cb reuse = nullptr; + + if (arch == LLM_ARCH_GEMMA3N) { + reuse = [&](int32_t il) { + if (il >= (int32_t) hparams.n_layer_kv_from_start) { + return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); + } + + return -1; + }; + } + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); - res = new llama_kv_cache_unified_iswa( + res = new llama_kv_cache_iswa( *this, params.type_k, params.type_v, @@ -18119,13 +18967,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, n_ctx_per_stream, cparams.n_seq_max, cparams.n_ubatch, - padding); + padding, + nullptr, + reuse); } else { GGML_ASSERT(!hparams.is_swa_any()); - res = new llama_kv_cache_unified( + res = new llama_kv_cache( *this, - nullptr, params.type_k, params.type_v, !cparams.flash_attn, @@ -18135,7 +18984,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, padding, hparams.n_swa, - hparams.swa_type); + hparams.swa_type, + nullptr, + nullptr); } } } @@ -18154,7 +19005,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_LLAMA4: { - llm = std::make_unique(*this, params); + if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) { + llm = std::make_unique(*this, params); + } else { + llm = std::make_unique(*this, params); + } } break; case LLM_ARCH_DECI: { @@ -18182,6 +19037,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_JINA_BERT_V3: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: { @@ -18221,6 +19077,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_LLADA_MOE: + { + llm = std::make_unique(*this, params); + } + break; case LLM_ARCH_QWEN2VL: { llm = std::make_unique(*this, params); @@ -18294,6 +19155,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_GEMMA_EMBEDDING: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_STARCODER2: { llm = std::make_unique(*this, params); @@ -18329,7 +19194,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_OLMO2: { - llm = std::make_unique(*this, params); + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + llm = std::make_unique>(*this, params); + } else { + llm = std::make_unique>(*this, params); + } } break; case LLM_ARCH_OLMOE: { @@ -18398,6 +19267,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_NEMOTRON_H: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_EXAONE: { llm = std::make_unique(*this, params); @@ -18452,6 +19325,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_SEED_OSS: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_DOTS1: { llm = std::make_unique(*this, params); @@ -18510,6 +19387,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { return llm->res->get_gf(); } + // // interface implementation // @@ -18518,7 +19396,7 @@ llama_model_params llama_model_default_params() { llama_model_params result = { /*.devices =*/ nullptr, /*.tensor_buft_overrides =*/ nullptr, - /*.n_gpu_layers =*/ 0, + /*.n_gpu_layers =*/ 999, /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, /*.main_gpu =*/ 0, /*.tensor_split =*/ nullptr, @@ -18532,11 +19410,6 @@ llama_model_params llama_model_default_params() { /*.use_extra_bufts =*/ true, }; -#ifdef GGML_USE_METAL - // note: we usually have plenty of VRAM, so by default offload all layers to the GPU - result.n_gpu_layers = 999; -#endif - return result; } @@ -18628,6 +19501,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_NEMOTRON_H: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values @@ -18667,6 +19541,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GROK: case LLM_ARCH_DBRX: case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V3: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_STABLELM: @@ -18677,6 +19552,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN2MOE: case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3MOE: + case LLM_ARCH_LLADA_MOE: case LLM_ARCH_OLMO2: case LLM_ARCH_OLMOE: case LLM_ARCH_PHI2: @@ -18688,6 +19564,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: + case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: @@ -18704,6 +19581,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_LFM2: case LLM_ARCH_SMALLTHINKER: case LLM_ARCH_GLM4_MOE: + case LLM_ARCH_SEED_OSS: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 46f7d0480fa..b1981978e3a 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -28,6 +28,7 @@ enum llm_type { LLM_TYPE_80M, LLM_TYPE_109M, LLM_TYPE_137M, + LLM_TYPE_140M, LLM_TYPE_160M, LLM_TYPE_190M, LLM_TYPE_220M, @@ -36,13 +37,15 @@ enum llm_type { LLM_TYPE_270M, LLM_TYPE_335M, LLM_TYPE_350M, + LLM_TYPE_360M, LLM_TYPE_410M, LLM_TYPE_450M, LLM_TYPE_475M, - LLM_TYPE_537M, + LLM_TYPE_558M, LLM_TYPE_700M, LLM_TYPE_770M, LLM_TYPE_780M, + LLM_TYPE_950M, LLM_TYPE_0_3B, LLM_TYPE_0_5B, LLM_TYPE_0_6B, @@ -76,9 +79,11 @@ enum llm_type { LLM_TYPE_32B, LLM_TYPE_34B, LLM_TYPE_35B, + LLM_TYPE_36B, LLM_TYPE_40B, LLM_TYPE_65B, LLM_TYPE_70B, + LLM_TYPE_120B, LLM_TYPE_142B, LLM_TYPE_236B, LLM_TYPE_290B, diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 1d0361cc166..97228b2a693 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -725,7 +725,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // attention layers have a non-zero number of kv heads int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0); if (llama_model_has_encoder(&model)) { - n_attn_layer *= 3; + // now n_attn_layer is the number of attention layers in the encoder + // for each decoder block, there are 2 attention layers + n_attn_layer += 2 * model.hparams.dec_n_layer; } GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected"); } @@ -920,7 +922,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: new_type = tensor->type; new_data = tensor->data; new_size = ggml_nbytes(tensor); - LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0); + LLAMA_LOG_INFO("size = %8.3f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0); } else { const int64_t nelements = ggml_nelements(tensor); @@ -1037,8 +1039,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } close_ofstream(); - LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); - LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); + LLAMA_LOG_INFO("%s: model size = %8.2f MiB\n", __func__, total_size_org/1024.0/1024.0); + LLAMA_LOG_INFO("%s: quant size = %8.2f MiB\n", __func__, total_size_new/1024.0/1024.0); if (qs.n_fallback > 0) { LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n", diff --git a/examples/talk-llama/llama-sampling.cpp b/examples/talk-llama/llama-sampling.cpp index bfbf5fa2301..2186f827bf5 100644 --- a/examples/talk-llama/llama-sampling.cpp +++ b/examples/talk-llama/llama-sampling.cpp @@ -128,6 +128,89 @@ struct ring_buffer { std::vector data; }; +// writes result in res, does not mutate cur +static void llama_token_data_array_partial_sort(const llama_token_data_array & cur, int npartial, std::vector & res) { + static const auto comp = [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }; + + constexpr int nbuckets = 128; + constexpr float bucket_low = -10.0f; + constexpr float bucket_high = 10.0f; + constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); + constexpr float bucket_inter = -bucket_low * bucket_scale; + + std::vector bucket_idx; + std::vector histo(nbuckets, 0); + + std::vector bucket_ptrs; + + bucket_idx.reserve(cur.size); + + for (int i = 0; i < (int)cur.size; ++i) { + const float val = cur.data[i].logit; + int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); + ib = std::max(0, std::min(nbuckets - 1, ib)); + bucket_idx.push_back(ib); + ++histo[ib]; + } + int nhave = 0; + int ib = nbuckets - 1; + for ( ; ib >= 0; --ib) { + nhave += histo[ib]; + if (nhave >= npartial) { + break; + } + } + res.resize(nhave); + auto * ptr = res.data(); + bucket_ptrs.reserve(nbuckets - ib); + for (int j = nbuckets - 1; j >= ib; --j) { + bucket_ptrs.push_back(ptr); + ptr += histo[j]; + } + for (int i = 0; i < (int)cur.size; ++i) { + int j = bucket_idx[i]; + if (j >= ib) { + *bucket_ptrs[nbuckets - 1 - j]++ = cur.data[i]; + } + } + + ptr = res.data(); + int ndone = 0; + for (int j = nbuckets - 1; j > ib; --j) { + std::sort(ptr, ptr + histo[j], comp); + ptr += histo[j]; + ndone += histo[j]; + } + std::partial_sort(ptr, ptr + npartial - ndone, ptr + histo[ib], comp); +} + +// reduces the size of cur_p to npartial, keeping only the top npartial elements +static void llama_token_data_array_partial_sort_inplace(llama_token_data_array * cur_p, int npartial) { + static const auto comp = [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }; + + if (npartial <= 128) { + std::partial_sort(cur_p->data, cur_p->data + npartial, cur_p->data + cur_p->size, comp); + + cur_p->size = npartial; + cur_p->sorted = true; + + return; + } + + std::vector tmp; + + llama_token_data_array_partial_sort(*cur_p, npartial, tmp); + + std::copy(tmp.data(), tmp.data() + npartial, cur_p->data); + + cur_p->size = npartial; + cur_p->sorted = true; +} + static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) { // iterator for the probabilities #ifdef __GNUC__ @@ -200,18 +283,21 @@ static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) } } -static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { +static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_sort) { GGML_ASSERT(cur_p->size > 0); - // Sort the logits in descending order - if (!cur_p->sorted) { - std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); - cur_p->sorted = true; + // Sort the logits in descending order if requested + if (do_sort && !cur_p->sorted) { + llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size); } float max_l = cur_p->data[0].logit; + if (!cur_p->sorted) { + for (size_t i = 1; i < cur_p->size; ++i) { + max_l = std::max(max_l, cur_p->data[i].logit); + } + } + float cum_sum = 0.0f; for (size_t i = 0; i < cur_p->size; ++i) { @@ -226,7 +312,6 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { } static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) { - // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast // if (k >= (int32_t)cur_p->size) { // return; // } @@ -239,64 +324,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) // Sort scores in descending order if (!cur_p->sorted) { - auto comp = [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }; - if (k <= 128) { - std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp); - } else { - constexpr int nbuckets = 128; - constexpr float bucket_low = -10.0f; - constexpr float bucket_high = 10.0f; - constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); - constexpr float bucket_inter = -bucket_low * bucket_scale; - - std::vector bucket_idx(cur_p->size); - std::vector histo(nbuckets, 0); - - for (int i = 0; i < (int)cur_p->size; ++i) { - const float val = cur_p->data[i].logit; - int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); - ib = std::max(0, std::min(nbuckets - 1, ib)); - bucket_idx[i] = ib; - ++histo[ib]; - } - int nhave = 0; - int ib = nbuckets - 1; - for ( ; ib >= 0; --ib) { - nhave += histo[ib]; - if (nhave >= k) { - break; - } - } - std::vector tmp_tokens(nhave); - auto * ptr = tmp_tokens.data(); - std::vector bucket_ptrs; - bucket_ptrs.reserve(nbuckets - ib); - for (int j = nbuckets - 1; j >= ib; --j) { - bucket_ptrs.push_back(ptr); - ptr += histo[j]; - } - for (int i = 0; i < (int)cur_p->size; ++i) { - int j = bucket_idx[i]; - if (j >= ib) { - *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i]; - } - } - - ptr = tmp_tokens.data(); - int ndone = 0; - for (int j = nbuckets - 1; j > ib; --j) { - std::sort(ptr, ptr + histo[j], comp); - ptr += histo[j]; - ndone += histo[j]; - } - std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp); - - std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data)); - - } - cur_p->sorted = true; + llama_token_data_array_partial_sort_inplace(cur_p, k); } cur_p->size = k; @@ -576,9 +604,73 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl* static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_dist *) smpl->ctx; - llama_sampler_softmax_impl(cur_p); + // edge cases + if (cur_p->size == 0) { + cur_p->selected = -1; + return; + } + + cur_p->selected = 0; + + if (cur_p->size == 1) { + cur_p->data[0].p = 1.0f; + return; + } + + // max logit for numerical stability + float max_l = cur_p->data[0].logit; + if (!cur_p->sorted) { + for (size_t i = 1; i < cur_p->size; ++i) { + max_l = std::max(max_l, cur_p->data[i].logit); + } + } + + // apply softmax to obtain the probabilities + double sum_cum = 0.0f; + for (size_t i = 0; i < cur_p->size; ++i) { + float p = expf(cur_p->data[i].logit - max_l); + cur_p->data[i].p = p; + sum_cum += p; + } + +#if 1 + // sample from the obtained probabilities and normalize the probs in a single pass + // this is ~3x faster on Mac with full gpt-oss vocab than the version below + // + std::uniform_real_distribution dist(0.0f, 1.0f); + const double rnd = dist(ctx->rng); + + double sum_run = 0.0f; + const double sum_tgt = sum_cum*rnd; + + bool found = false; + for (size_t i = 0; i < cur_p->size; ++i) { + if (!found) { + // accumulate probs until we reach the target sum + sum_run += cur_p->data[i].p; + if (sum_run >= sum_tgt) { + cur_p->selected = i; + found = true; + } + } + + // normalize probs + cur_p->data[i].p /= sum_cum; + } + + // fallback to the last token (don't think this can happen) + assert(found); + if (!found) { + cur_p->selected = cur_p->size - 1; + } +#else + // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= sum_cum; + } cur_p->selected = llama_sample_dist(cur_p, ctx->rng); +#endif } static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { @@ -626,32 +718,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { ); } -// softmax - -static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) { - return "softmax"; -} - -static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { - llama_sampler_softmax_impl(cur_p); -} - -static struct llama_sampler_i llama_sampler_softmax_i = { - /* .name = */ llama_sampler_softmax_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_softmax_apply, - /* .reset = */ nullptr, - /* .clone = */ nullptr, - /* .free = */ nullptr, -}; - -struct llama_sampler * llama_sampler_init_softmax() { - return llama_sampler_init( - /* .iface = */ &llama_sampler_softmax_i, - /* .ctx = */ nullptr - ); -} - // top-k struct llama_sampler_top_k { @@ -663,7 +729,7 @@ static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl } static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_top_k *) smpl->ctx; + auto * ctx = (llama_sampler_top_k *) smpl->ctx; llama_sampler_top_k_impl(cur_p, ctx->k); } @@ -699,6 +765,8 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) { struct llama_sampler_top_p { const float p; const size_t min_keep; + + std::vector buf_sort; }; static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) { @@ -706,20 +774,35 @@ static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl } static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_top_p *) smpl->ctx; + auto * ctx = (llama_sampler_top_p *) smpl->ctx; if (ctx->p >= 1.0f) { return; } - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, false); + + size_t k = cur_p->size; + auto * pdata = cur_p->data; + + auto & buf_sort = ctx->buf_sort; + + // if not sorted, try adaptive top-k sorting + if (!cur_p->sorted && cur_p->size > 1024) { + k = std::min(256, cur_p->size); + llama_token_data_array_partial_sort(*cur_p, k, buf_sort); + pdata = buf_sort.data(); + } else if (!cur_p->sorted) { + // small candidates -> sort inplace + llama_token_data_array_partial_sort_inplace(cur_p, k); + } // Compute the cumulative probabilities float cum_sum = 0.0f; size_t last_idx = cur_p->size; for (size_t i = 0; i < cur_p->size; ++i) { - cum_sum += cur_p->data[i].p; + cum_sum += pdata[i].p; // Check if the running sum is at least p or if we have kept at least min_keep tokens // we set the last index to i+1 to indicate that the current iterate should be included in the set @@ -727,9 +810,21 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d last_idx = i + 1; break; } + + // we exceeded the current top-k heuristic -> increase k and continue + if (!cur_p->sorted && i == k - 1) { + k = cur_p->size; + llama_token_data_array_partial_sort(*cur_p, k, buf_sort); + pdata = buf_sort.data(); + } } // Resize the output vector to keep only the top-p tokens + if (!cur_p->sorted) { + std::copy(buf_sort.data(), buf_sort.data() + last_idx, cur_p->data); + cur_p->sorted = true; + } + cur_p->size = last_idx; } @@ -757,6 +852,7 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { /* .ctx = */ new llama_sampler_top_p { /* .p = */ p, /* .min_keep = */ min_keep, + /* .buf_sort = */ {}, } ); } @@ -773,7 +869,7 @@ static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl } static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_min_p *) smpl->ctx; + auto * ctx = (llama_sampler_min_p *) smpl->ctx; if (ctx->p <= 0.0f || !cur_p->size) { return; @@ -799,7 +895,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d // if we have enough values the operation was a success if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) { - memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); + std::copy(filtered_tokens.begin(), filtered_tokens.end(), cur_p->data); cur_p->size = filtered_tokens.size(); min_p_applied = true; } @@ -809,10 +905,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d if (!min_p_applied) { // Sort the logits in descending order if (!cur_p->sorted) { - std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); - cur_p->sorted = true; + llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size); } const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max @@ -869,7 +962,7 @@ static const char * llama_sampler_typical_name(const struct llama_sampler * /*sm } static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_typical *) smpl->ctx; + auto * ctx = (llama_sampler_typical *) smpl->ctx; // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr @@ -878,7 +971,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token } // Compute the softmax of logits and calculate entropy - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, true); float entropy = 0.0f; for (size_t i = 0; i < cur_p->size; ++i) { @@ -1012,7 +1105,7 @@ static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*s } static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx; + auto * ctx = (llama_sampler_temp_ext *) smpl->ctx; if (ctx->delta > 0) { const float min_temp = std::max(0.0f, ctx->temp - ctx->delta); const float max_temp = ctx->temp + ctx->delta; @@ -1027,7 +1120,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke // Calculate maximum possible entropy float max_entropy = -logf(1.0f / cur_p->size); - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, true); // Calculate entropy of the softmax probabilities float entropy = 0.0f; @@ -1121,7 +1214,7 @@ struct llama_sampler_xtc { const uint32_t seed; uint32_t seed_cur; - std::mt19937 rng; + std::mt19937 rng; }; static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) { @@ -1139,17 +1232,20 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data std::uniform_real_distribution distribution(0.0f, 1.0f); float chance = distribution(ctx->rng); - if (chance > ctx->probability) return; + if (chance > ctx->probability) { + return; + } - // in case it's not sorted/recalculated yet - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, true); int pos_last = 0; for (size_t i = 0; i < cur_p->size; ++i) { if (cur_p->data[i].p >= ctx->threshold) { pos_last = i; - } else break; + } else { + break; + } } if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) { @@ -1221,7 +1317,7 @@ struct llama_sampler_mirostat { float mu; - std::mt19937 rng; + std::mt19937 rng; }; static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) { @@ -1231,7 +1327,7 @@ static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*s static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_mirostat *) smpl->ctx; - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, true); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; @@ -1250,7 +1346,8 @@ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_toke float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat); llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); - llama_sampler_softmax_impl(cur_p); + + llama_sampler_softmax_impl(cur_p, true); const int idx = llama_sample_dist(cur_p, ctx->rng); @@ -1336,7 +1433,7 @@ static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, true); // Truncate the words with surprise values greater than mu cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { @@ -1348,7 +1445,7 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t } // Normalize the probabilities of the remaining words - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, true); const int idx = llama_sample_dist(cur_p, ctx->rng); @@ -1540,7 +1637,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0"); } trigger_pattern += ")[\\s\\S]*"; - auto trigger_pattern_c = trigger_pattern.c_str(); + const auto * trigger_pattern_c = trigger_pattern.c_str(); trigger_patterns = &trigger_pattern_c; num_trigger_patterns = 1; } @@ -1748,7 +1845,7 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * } static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; + auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; if (ctx->n <= 0.0f || cur_p->size <= 1) { return; @@ -1780,13 +1877,14 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t } float std = valid_count > 0 ? sqrt(acc/valid_count) : 0; - //apply mask + // apply mask for (size_t i = 0; i < cur_p->size; ++i) { if (cur_p->data[i].logit < max - (ctx->n * std)) { cur_p->data[i].logit = -INFINITY; } } - llama_sampler_softmax_impl(cur_p); + + llama_sampler_softmax_impl(cur_p, true); } static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) { @@ -1991,7 +2089,9 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat { const int last = last_n_repeat - 1; - int rt = 0, lt = 0; + + int rt = 0; + int lt = 0; for (int k = 1; k < last_n_repeat; ++k) { if (k > rt) { @@ -2135,8 +2235,8 @@ static struct llama_sampler_i llama_sampler_dry_i = { /* .free = */ llama_sampler_dry_free, }; -struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { - int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0); +struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { + int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? n_ctx_train : std::max(dry_penalty_last_n, 0); std::unordered_multimap> processed_breakers; const int MAX_CHAR_LEN = 40; const int MAX_SEQ_LEN = 20; @@ -2169,7 +2269,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, return llama_sampler_init( /* .iface = */ &llama_sampler_dry_i, /* .ctx = */ new llama_sampler_dry { - /* .total_context_size = */ context_size, + /* .total_context_size = */ n_ctx_train, /* .dry_multiplier = */ dry_multiplier, /* .dry_base = */ dry_base, /* .dry_allowed_length = */ dry_allowed_length, @@ -2308,7 +2408,7 @@ static const char * llama_sampler_infill_name(const struct llama_sampler * /*smp static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_infill *) smpl->ctx; - llama_sampler_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p, true); #if defined(GGML_DEBUG_SAMPLER_INFILL) #define LOG_DBG_CUR LLAMA_LOG_DEBUG diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index de5d1681dff..8cb36661a0c 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -434,6 +434,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_GROK_2: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -1955,7 +1962,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION; clean_spaces = false; } else if ( - tokenizer_pre == "bailingmoe") { + tokenizer_pre == "bailingmoe" || + tokenizer_pre == "llada-moe") { pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; clean_spaces = false; } else if ( @@ -1974,6 +1982,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "kimi-k2") { pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; clean_spaces = false; + } else if ( + tokenizer_pre == "grok-2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -2470,7 +2482,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // set attributes by model/tokenizer/architecture name if (false || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"}) - || _contains_any(general_arch, {"nomic-bert-moe"}) + || _contains_any(general_arch, {"nomic-bert-moe", "jina-bert-v3"}) ) { if (token_to_id.count("") == 0) { LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__); diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 61b81242168..0d2f28c36c8 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -47,6 +47,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, + LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, }; struct LLM_KV; diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index 34906cdb628..fe5a7a83548 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -25,6 +25,18 @@ // interface implementation // +const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type) { + switch (flash_attn_type) { + case LLAMA_FLASH_ATTN_TYPE_AUTO: + return "auto"; + case LLAMA_FLASH_ATTN_TYPE_DISABLED: + return "disabled"; + case LLAMA_FLASH_ATTN_TYPE_ENABLED: + return "enabled"; + } + GGML_ABORT("fatal error"); +} + struct llama_sampler_chain_params llama_sampler_chain_default_params() { struct llama_sampler_chain_params result = { /*.no_perf =*/ true, @@ -47,6 +59,7 @@ bool llama_supports_mlock(void) { bool llama_supports_gpu_offload(void) { return ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU) != nullptr || + ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU) != nullptr || llama_supports_rpc(); } @@ -71,7 +84,9 @@ void llama_numa_init(enum ggml_numa_strategy numa) { GGML_ASSERT(dev && "CPU backend is not loaded"); auto * reg = ggml_backend_dev_backend_reg(dev); auto * numa_init_fn = (decltype(ggml_numa_init) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_numa_init"); - numa_init_fn(numa); + if (numa_init_fn) { + numa_init_fn(numa); + } } } @@ -170,8 +185,13 @@ static struct llama_model * llama_model_load_from_file_impl( model->devices.push_back(*dev); } } else { + // default device selection + + // build list of available devices + std::vector gpus; + std::vector igpus; std::vector rpc_servers; - // use all available devices + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { ggml_backend_dev_t dev = ggml_backend_dev_get(i); switch (ggml_backend_dev_type(dev)) { @@ -180,19 +200,51 @@ static struct llama_model * llama_model_load_from_file_impl( // skip CPU backends since they are handled separately break; - case GGML_BACKEND_DEVICE_TYPE_GPU: + case GGML_BACKEND_DEVICE_TYPE_GPU: { ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); if (ggml_backend_reg_name(reg) == std::string("RPC")) { rpc_servers.push_back(dev); } else { - model->devices.push_back(dev); + // check if there is already a GPU with the same device id + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + auto it = std::find_if(gpus.begin(), gpus.end(), [&props](ggml_backend_dev_t d) { + ggml_backend_dev_props d_props; + ggml_backend_dev_get_props(d, &d_props); + if (props.device_id && d_props.device_id) { + return strcmp(props.device_id, d_props.device_id) == 0; + } + return false; + }); + + if (it != gpus.end()) { + LLAMA_LOG_INFO("%s: skipping device %s (%s) with id %s - already using device %s (%s) with the same id\n", + __func__, + ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + props.device_id ? props.device_id : "unknown id", + ggml_backend_dev_name(*it), ggml_backend_dev_description(*it)); + } else { + gpus.push_back(dev); + } } break; + } + + case GGML_BACKEND_DEVICE_TYPE_IGPU: + igpus.push_back(dev); + break; } } - // add RPC servers at the front of the list - if (!rpc_servers.empty()) { - model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end()); + + // add RPC servers at the front of the list to minimize network transfers + model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end()); + + // add GPUs + model->devices.insert(model->devices.end(), gpus.begin(), gpus.end()); + + // add integrated GPUs only if no other devices were found + if (model->devices.empty()) { + model->devices.insert(model->devices.end(), igpus.begin(), igpus.end()); } } @@ -213,9 +265,12 @@ static struct llama_model * llama_model_load_from_file_impl( } for (auto * dev : model->devices) { - size_t free, total; // NOLINT - ggml_backend_dev_memory(dev, &free, &total); - LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024); + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + LLAMA_LOG_INFO("%s: using device %s (%s) (%s) - %zu MiB free\n", __func__, + ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + props.device_id ? props.device_id : "unknown id", + props.memory_free/1024/1024); } const int status = llama_model_load(path_model, splits, *model, params); diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 135eaf1b655..453190e852b 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -64,8 +64,6 @@ extern "C" { typedef struct llama_memory_i * llama_memory_t; - struct llama_kv_cache; // DEPRECATED (use llama_memory instead) - typedef int32_t llama_pos; typedef int32_t llama_token; typedef int32_t llama_seq_id; @@ -181,6 +179,14 @@ extern "C" { LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1, }; + enum llama_flash_attn_type { + LLAMA_FLASH_ATTN_TYPE_AUTO = -1, + LLAMA_FLASH_ATTN_TYPE_DISABLED = 0, + LLAMA_FLASH_ATTN_TYPE_ENABLED = 1, + }; + + LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type); + enum llama_split_mode { LLAMA_SPLIT_MODE_NONE = 0, // single GPU LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs @@ -200,7 +206,7 @@ extern "C" { llama_token_data * data; size_t size; int64_t selected; // this is the index in the data array (i.e. not the token id) - bool sorted; + bool sorted; // note: do not assume the data is sorted - always check this flag } llama_token_data_array; typedef bool (*llama_progress_callback)(float progress, void * user_data); @@ -305,6 +311,7 @@ extern "C" { enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings + enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention // ref: https://github.com/ggml-org/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency, 0 = from model @@ -314,7 +321,7 @@ extern "C" { float yarn_beta_fast; // YaRN low correction dim float yarn_beta_slow; // YaRN high correction dim uint32_t yarn_orig_ctx; // YaRN original context size - float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default) + float defrag_thold; // [DEPRECATED] defragment the KV cache if holes/size > thold, <= 0 disabled (default) ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; @@ -331,7 +338,6 @@ extern "C" { // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU - bool flash_attn; // use flash attention [EXPERIMENTAL] bool no_perf; // measure performance timings bool op_offload; // offload host tensor operations to device bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) @@ -469,8 +475,6 @@ extern "C" { LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx); LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type - DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead"); - LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); @@ -557,10 +561,32 @@ extern "C" { struct llama_model * model, const char * path_lora); + // Functions to access the adapter's GGUF metadata scalar values + // - The functions return the length of the string on success, or -1 on failure + // - The output string is always null-terminated and cleared on failure + // - When retrieving a string, an extra byte must be allocated to account for the null terminator + // - GGUF array values are not supported by these functions + + // Get metadata value as a string by key name + LLAMA_API int32_t llama_adapter_meta_val_str(const struct llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size); + + // Get the number of metadata key/value pairs + LLAMA_API int32_t llama_adapter_meta_count(const struct llama_adapter_lora * adapter); + + // Get metadata key name by index + LLAMA_API int32_t llama_adapter_meta_key_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size); + + // Get metadata value as a string by index + LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size); + // Manually free a LoRA adapter // Note: loaded adapters will be free when the associated model is deleted LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); + // Get the invocation tokens if the current lora is an alora + LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter); + LLAMA_API const llama_token * llama_adapter_get_alora_invocation_tokens (const struct llama_adapter_lora * adapter); + // The following functions operate on a llama_context, hence the naming: llama_verb_... // Add a loaded LoRA adapter to given context @@ -667,111 +693,6 @@ extern "C" { // Check if the memory supports shifting LLAMA_API bool llama_memory_can_shift(llama_memory_t mem); - // - // KV cache for self-attention (TODO: deprecate in favor of llama_memory) - // - - // Returns the number of tokens in the KV cache (slow, use only for debug) - // If a KV cell has multiple sequences assigned to it, it will be counted multiple times - DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx), - "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); - - // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) - DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx), - "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); - - // Clear the KV cache - both cell info is erased and KV data is zeroed - DEPRECATED(LLAMA_API void llama_kv_self_clear( - struct llama_context * ctx), - "Use llama_memory_clear() instead"); - - // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) - // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails - // seq_id < 0 : match any sequence - // p0 < 0 : [0, p1] - // p1 < 0 : [p0, inf) - DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1), - "Use llama_memory_seq_rm() instead"); - - // Copy all tokens that belong to the specified sequence to another sequence - // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence - // p0 < 0 : [0, p1] - // p1 < 0 : [p0, inf) - DEPRECATED(LLAMA_API void llama_kv_self_seq_cp( - struct llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1), - "Use llama_memory_seq_cp() instead"); - - // Removes all tokens that do not belong to the specified sequence - DEPRECATED(LLAMA_API void llama_kv_self_seq_keep( - struct llama_context * ctx, - llama_seq_id seq_id), - "Use llama_memory_seq_keep() instead"); - - // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) - // If the KV cache is RoPEd, the KV data is updated accordingly: - // - lazily on next llama_decode() - // p0 < 0 : [0, p1] - // p1 < 0 : [p0, inf) - DEPRECATED(LLAMA_API void llama_kv_self_seq_add( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta), - "Use llama_memory_seq_add() instead"); - - // Integer division of the positions by factor of `d > 1` - // If the KV cache is RoPEd, the KV data is updated accordingly: - // - lazily on next llama_decode() - // p0 < 0 : [0, p1] - // p1 < 0 : [p0, inf) - DEPRECATED(LLAMA_API void llama_kv_self_seq_div( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d), - "Use llama_memory_seq_div() instead"); - - // Returns the smallest position present in the KV cache for the specified sequence - // This is typically non-zero only for SWA caches - // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache - // Return -1 if the sequence is empty - DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min( - struct llama_context * ctx, - llama_seq_id seq_id), - "Use llama_memory_seq_pos_min() instead"); - - // Returns the largest position present in the KV cache for the specified sequence - // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache - // Return -1 if the sequence is empty - DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max( - struct llama_context * ctx, - llama_seq_id seq_id), - "Use llama_memory_seq_pos_max() instead"); - - // Defragment the KV cache - // This will be applied: - // - lazily on next llama_decode() - DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx), - "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); - - // Check if the context supports KV cache shifting - DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx), - "use llama_memory_can_shift() instead"); - - // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) - DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx), - "simply remove this call, updates are applied lazily on the next llama_decode()"); - // // State / sessions // @@ -1239,11 +1160,6 @@ extern "C" { LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void); LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); - /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first. - DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void), - "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)"); - /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 /// Setting k <= 0 makes this a noop LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp index b4219c29446..239c56902d4 100644 --- a/examples/talk-llama/talk-llama.cpp +++ b/examples/talk-llama/talk-llama.cpp @@ -340,9 +340,10 @@ int main(int argc, char ** argv) { llama_context_params lcparams = llama_context_default_params(); // tune these to your liking - lcparams.n_ctx = 2048; - lcparams.n_threads = params.n_threads; - lcparams.flash_attn = params.flash_attn; + lcparams.n_ctx = 2048; + lcparams.n_threads = params.n_threads; + + lcparams.flash_attn_type = params.flash_attn ? LLAMA_FLASH_ATTN_TYPE_AUTO : LLAMA_FLASH_ATTN_TYPE_DISABLED; struct llama_context * ctx_llama = llama_init_from_model(model_llama, lcparams); diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 90e274ccdbc..b113b68fae6 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -1,5 +1,41 @@ cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories. -project("ggml" C CXX) +project("ggml" C CXX ASM) + +### GGML Version +set(GGML_VERSION_MAJOR 0) +set(GGML_VERSION_MINOR 9) +set(GGML_VERSION_PATCH 0) +set(GGML_VERSION_DEV "-dev") # "-dev" for development, "" for releases +set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") + +find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) +if(GIT_EXE) + # Get current git commit hash + execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE GGML_BUILD_COMMIT + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + ) + + # Check if the working directory is dirty (i.e., has uncommitted changes) + execute_process(COMMAND ${GIT_EXE} diff-index --quiet HEAD -- . + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE GGML_GIT_DIRTY + ERROR_QUIET + ) +endif() + +# Build the version string with optional -dev suffix and dirty flag +set(GGML_VERSION "${GGML_VERSION_BASE}${GGML_VERSION_DEV}") +if(GGML_GIT_DIRTY AND NOT GGML_GIT_DIRTY EQUAL 0) + set(GGML_VERSION "${GGML_VERSION}-dirty") +endif() + +if(NOT GGML_BUILD_COMMIT) + set(GGML_BUILD_COMMIT "unknown") +endif() + include(CheckIncludeFileCXX) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -129,10 +165,11 @@ endif() option(GGML_LASX "ggml: enable lasx" ON) option(GGML_LSX "ggml: enable lsx" ON) option(GGML_RVV "ggml: enable rvv" ON) -option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF) +option(GGML_RV_ZFH "ggml: enable riscv zfh" ON) +option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON) +option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) option(GGML_VXE "ggml: enable vxe" ON) -option(GGML_NNPA "ggml: enable nnpa" OFF) # temp disabled by default, see: https://github.com/ggml-org/llama.cpp/issues/14877 option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") @@ -158,7 +195,6 @@ option(GGML_CUDA "ggml: use CUDA" option(GGML_MUSA "ggml: use MUSA" OFF) option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF) option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF) -option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF) set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING "ggml: max. batch size for using peer access") option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF) @@ -301,26 +337,6 @@ endif() # Create CMake package # -# Generate version info based on git commit. - -if(NOT DEFINED GGML_BUILD_NUMBER) - find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH) - execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - OUTPUT_VARIABLE GGML_BUILD_NUMBER - OUTPUT_STRIP_TRAILING_WHITESPACE - ) - - if(GGML_BUILD_NUMBER EQUAL 1) - message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.") - endif() - - execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - OUTPUT_VARIABLE GGML_BUILD_COMMIT - OUTPUT_STRIP_TRAILING_WHITESPACE - ) -endif() # Capture variables prefixed with GGML_. @@ -349,7 +365,7 @@ set(GGML_VARIABLES_EXPANDED ${variable_set_statements}) # Create the CMake package and set install location. -set(GGML_INSTALL_VERSION 0.0.${GGML_BUILD_NUMBER}) +set(GGML_INSTALL_VERSION ${GGML_VERSION}) set(GGML_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files") set(GGML_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") set(GGML_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index a2977ea2e56..ab297e0c6f6 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -132,6 +132,8 @@ extern "C" { GGML_BACKEND_DEVICE_TYPE_CPU, // GPU device using dedicated memory GGML_BACKEND_DEVICE_TYPE_GPU, + // integrated GPU device using host memory + GGML_BACKEND_DEVICE_TYPE_IGPU, // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX) GGML_BACKEND_DEVICE_TYPE_ACCEL }; @@ -150,11 +152,21 @@ extern "C" { // all the device properties struct ggml_backend_dev_props { + // device name const char * name; + // device description const char * description; + // device free memory in bytes size_t memory_free; + // device total memory in bytes size_t memory_total; + // device type enum ggml_backend_dev_type type; + // device id + // for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0") + // if the id is unknown, this should be NULL + const char * device_id; + // device capabilities struct ggml_backend_dev_caps caps; }; @@ -307,6 +319,9 @@ extern "C" { GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); + // Split graph without allocating it + GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); + // Allocate and compute graph on the backend scheduler GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph); diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index be40b100979..9edd4851369 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -101,7 +101,6 @@ extern "C" { GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); GGML_BACKEND_API int ggml_cpu_has_vsx (void); GGML_BACKEND_API int ggml_cpu_has_vxe (void); - GGML_BACKEND_API int ggml_cpu_has_nnpa (void); GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void); GGML_BACKEND_API int ggml_cpu_has_llamafile (void); @@ -135,6 +134,7 @@ extern "C" { GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t); + GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t); GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t); GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t); GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t); diff --git a/ggml/include/ggml-metal.h b/ggml/include/ggml-metal.h index a6106944234..433838f0d6d 100644 --- a/ggml/include/ggml-metal.h +++ b/ggml/include/ggml-metal.h @@ -39,18 +39,13 @@ extern "C" { // user-code should use only these functions // +// TODO: remove in the future GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void); GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend); -GGML_DEPRECATED( - GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size), - "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713"); - GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data); -GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); - // helper to check if the device supports a specific family // ideally, the user code should be doing these checks // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf diff --git a/ggml/include/ggml-zdnn.h b/ggml/include/ggml-zdnn.h index c2c30c977c3..69fb558d873 100644 --- a/ggml/include/ggml-zdnn.h +++ b/ggml/include/ggml-zdnn.h @@ -7,8 +7,6 @@ extern "C" { #endif -GGML_BACKEND_API ggml_backend_t ggml_backend_zdnn_init(void); - GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zdnn_reg(void); #ifdef __cplusplus diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index da8813fd278..36b23dc6d0d 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -244,6 +244,13 @@ #define GGML_MROPE_SECTIONS 4 #define GGML_UNUSED(x) (void)(x) +#ifdef __CUDACC__ +template +__host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexcept {} +#define GGML_UNUSED_VARS(...) ggml_unused_vars_impl(__VA_ARGS__) +#else +#define GGML_UNUSED_VARS(...) do { (void)sizeof((__VA_ARGS__, 0)); } while(0) +#endif // __CUDACC__ #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) @@ -277,19 +284,19 @@ // GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); // #define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \ - const type prefix##0 = (pointer)->array[0]; \ + const type prefix##0 = (pointer) ? (pointer)->array[0] : 0; \ GGML_UNUSED(prefix##0); #define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \ GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \ - const type prefix##1 = (pointer)->array[1]; \ + const type prefix##1 = (pointer) ? (pointer)->array[1] : 0; \ GGML_UNUSED(prefix##1); #define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \ GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \ - const type prefix##2 = (pointer)->array[2]; \ + const type prefix##2 = (pointer) ? (pointer)->array[2] : 0; \ GGML_UNUSED(prefix##2); #define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \ GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \ - const type prefix##3 = (pointer)->array[3]; \ + const type prefix##3 = (pointer) ? (pointer)->array[3] : 0; \ GGML_UNUSED(prefix##3); #define GGML_TENSOR_UNARY_OP_LOCALS \ @@ -504,7 +511,9 @@ extern "C" { GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_IM2COL, GGML_OP_IM2COL_BACK, + GGML_OP_IM2COL_3D, GGML_OP_CONV_2D, + GGML_OP_CONV_3D, GGML_OP_CONV_2D_DW, GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_POOL_1D, @@ -1395,6 +1404,7 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // note: casting from f32 to i32 will discard the fractional part GGML_API struct ggml_tensor * ggml_cast( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1519,7 +1529,11 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // supports 3D: a->ne[2] == b->ne[1] + // supports 4D a: + // a [n_embd, ne1, ne2, ne3] + // b I32 [n_rows, ne2, ne3, 1] + // + // return [n_embd, n_rows, ne2, ne3] GGML_API struct ggml_tensor * ggml_get_rows( struct ggml_context * ctx, struct ggml_tensor * a, // data @@ -1862,6 +1876,41 @@ extern "C" { int d0, // dilation dimension 0 int d1); // dilation dimension 1 + GGML_API struct ggml_tensor * ggml_im2col_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t IC, + int s0, // stride width + int s1, // stride height + int s2, // stride depth + int p0, // padding width + int p1, // padding height + int p2, // padding depth + int d0, // dilation width + int d1, // dilation height + int d2, // dilation depth + enum ggml_type dst_type); + + // a: [OC*IC, KD, KH, KW] + // b: [N*IC, ID, IH, IW] + // result: [N*OC, OD, OH, OW] + GGML_API struct ggml_tensor * ggml_conv_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t IC, + int s0, // stride width + int s1, // stride height + int s2, // stride depth + int p0, // padding width + int p1, // padding height + int p2, // padding depth + int d0, // dilation width + int d1, // dilation height + int d2 // dilation depth + ); + // kernel size is a->ne[0] x a->ne[1] // stride is equal to kernel size // padding is zero @@ -1933,6 +1982,23 @@ extern "C" { int d0, // dilation dimension 0 int d1); // dilation dimension 1 + GGML_API struct ggml_tensor * ggml_conv_3d_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC] + struct ggml_tensor * b, // input [W, H, D, C * N] + int s0, // stride + int s1, + int s2, + int p0, // padding + int p1, + int p2, + int d0, // dilation + int d1, + int d2, + int n_channels, + int n_batch, + int n_channels_out); + enum ggml_op_pool { GGML_OP_POOL_MAX, GGML_OP_POOL_AVG, @@ -2023,6 +2089,19 @@ extern "C" { int p2, int p3); + GGML_API struct ggml_tensor * ggml_pad_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + int lp0, + int rp0, + int lp1, + int rp1, + int lp2, + int rp2, + int lp3, + int rp3 + ); + // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c] GGML_API struct ggml_tensor * ggml_pad_reflect_1d( struct ggml_context * ctx, diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 2b5b8169d75..c8f3d859642 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -114,6 +114,9 @@ message(STATUS "GGML_SYSTEM_ARCH: ${GGML_SYSTEM_ARCH}") if (NOT MSVC) if (GGML_STATIC) + if (UNIX AND NOT APPLE) + set(CMAKE_FIND_LIBRARY_SUFFIXES ".a;.so") + endif() add_link_options(-static) if (MINGW) add_link_options(-static-libgcc -static-libstdc++) diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index c36c12d6579..07784d6f66c 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -8,7 +8,7 @@ extern "C" { #endif - #define GGML_BACKEND_API_VERSION 1 + #define GGML_BACKEND_API_VERSION 2 // // Backend buffer type @@ -114,6 +114,9 @@ extern "C" { void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event); // wait for an event on on a different stream void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event); + + // (optional) sort/optimize the nodes in the graph + void (*graph_optimize) (ggml_backend_t backend, struct ggml_cgraph * cgraph); }; struct ggml_backend { diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 5f02a710a14..7002cb07e00 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -400,9 +400,8 @@ ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const ggml_backend_t ggml_backend_init_best(void) { ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU); - if (!dev) { - dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - } + dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU); + dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); if (!dev) { return nullptr; } diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 1b9d29e911f..79a5282be37 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -19,9 +19,8 @@ #include #include #include -#include -#include #include +#include #ifdef __APPLE__ #include @@ -32,6 +31,7 @@ // backend buffer type const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(buft); return buft->iface.get_name(buft); } @@ -41,14 +41,17 @@ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t return ggml_backend_buffer_init(buft, {}, NULL, 0); } + GGML_ASSERT(buft); return buft->iface.alloc_buffer(buft, size); } size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(buft); return buft->iface.get_alignment(buft); } size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(buft); // get_max_size is optional, defaults to SIZE_MAX if (buft->iface.get_max_size) { return buft->iface.get_max_size(buft); @@ -57,6 +60,7 @@ size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) { } size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { + GGML_ASSERT(buft); // get_alloc_size is optional, defaults to ggml_nbytes if (buft->iface.get_alloc_size) { size_t size = buft->iface.get_alloc_size(buft, tensor); @@ -67,6 +71,7 @@ size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const s } bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(buft); if (buft->iface.is_host) { return buft->iface.is_host(buft); } @@ -74,6 +79,7 @@ bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) { } ggml_backend_dev_t ggml_backend_buft_get_device(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(buft); return buft->device; } @@ -111,10 +117,12 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { } size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); return buffer->size; } void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); // get_base is optional if the buffer is zero-sized if (buffer->size == 0) { return NULL; @@ -128,6 +136,7 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { } enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + GGML_ASSERT(buffer); // init_tensor is optional if (buffer->iface.init_tensor) { return buffer->iface.init_tensor(buffer, tensor); @@ -136,6 +145,7 @@ enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, s } void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_ASSERT(buffer); // clear is optional if the buffer is zero-sized if (buffer->size == 0) { return; @@ -161,6 +171,7 @@ bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) { } void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) { + GGML_ASSERT(buffer); buffer->usage = usage; // FIXME: add a generic callback to the buffer interface @@ -170,14 +181,17 @@ void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backe } enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); return buffer->usage; } ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); return buffer->buft; } void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); if (buffer->iface.reset) { buffer->iface.reset(buffer); } @@ -216,6 +230,7 @@ void ggml_backend_free(ggml_backend_t backend) { } ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) { + GGML_ASSERT(backend); return ggml_backend_dev_buffer_type(backend->device); } @@ -232,6 +247,8 @@ size_t ggml_backend_get_max_size(ggml_backend_t backend) { } void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(backend); + GGML_ASSERT(tensor); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); @@ -243,6 +260,8 @@ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * } void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(backend); + GGML_ASSERT(tensor); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); @@ -284,6 +303,7 @@ void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, siz } void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; if (size == 0) { @@ -299,6 +319,7 @@ void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size } void ggml_backend_synchronize(ggml_backend_t backend) { + GGML_ASSERT(backend); if (backend->iface.synchronize == NULL) { return; } @@ -307,18 +328,21 @@ void ggml_backend_synchronize(ggml_backend_t backend) { } ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(backend); GGML_ASSERT(backend->iface.graph_plan_create != NULL); return backend->iface.graph_plan_create(backend, cgraph); } void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + GGML_ASSERT(backend); GGML_ASSERT(backend->iface.graph_plan_free != NULL); backend->iface.graph_plan_free(backend, plan); } enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + GGML_ASSERT(backend); GGML_ASSERT(backend->iface.graph_plan_compute != NULL); return backend->iface.graph_plan_compute(backend, plan); @@ -331,22 +355,27 @@ enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_ } enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(backend); return backend->iface.graph_compute(backend, cgraph); } bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { + GGML_ASSERT(backend); return ggml_backend_dev_supports_op(backend->device, op); } bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { + GGML_ASSERT(backend); return ggml_backend_dev_supports_buft(backend->device, buft); } bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) { + GGML_ASSERT(backend); return ggml_backend_dev_offload_op(backend->device, op); } ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) { + GGML_ASSERT(backend); return backend->device; } @@ -382,6 +411,7 @@ void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t b return; } + GGML_ASSERT(backend_dst); if (backend_dst->iface.cpy_tensor_async != NULL) { if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) { return; @@ -413,38 +443,52 @@ void ggml_backend_event_free(ggml_backend_event_t event) { } void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend) { + GGML_ASSERT(backend); GGML_ASSERT(backend->iface.event_record != NULL); backend->iface.event_record(backend, event); } void ggml_backend_event_synchronize(ggml_backend_event_t event) { + GGML_ASSERT(event); GGML_ASSERT(event->device->iface.event_synchronize); event->device->iface.event_synchronize(event->device, event); } void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + GGML_ASSERT(backend); GGML_ASSERT(backend->iface.event_wait != NULL); backend->iface.event_wait(backend, event); } +static void ggml_backend_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(backend); + if (backend->iface.graph_optimize != NULL) { + backend->iface.graph_optimize(backend, cgraph); + } +} + // Backend device const char * ggml_backend_dev_name(ggml_backend_dev_t device) { + GGML_ASSERT(device); return device->iface.get_name(device); } const char * ggml_backend_dev_description(ggml_backend_dev_t device) { + GGML_ASSERT(device); return device->iface.get_description(device); } void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + GGML_ASSERT(device); device->iface.get_memory(device, free, total); } enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) { + GGML_ASSERT(device); return device->iface.get_type(device); } @@ -454,18 +498,22 @@ void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_d } ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device) { + GGML_ASSERT(device); return device->reg; } ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params) { + GGML_ASSERT(device); return device->iface.init_backend(device, params); } ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) { + GGML_ASSERT(device); return device->iface.get_buffer_type(device); } ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) { + GGML_ASSERT(device); if (device->iface.get_host_buffer_type == NULL) { return NULL; } @@ -474,18 +522,22 @@ ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t } ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size) { + GGML_ASSERT(device); return device->iface.buffer_from_host_ptr(device, ptr, size, max_tensor_size); } bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op) { + GGML_ASSERT(device); return device->iface.supports_op(device, op); } bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft) { + GGML_ASSERT(device); return device->iface.supports_buft(device, buft); } bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op) { + GGML_ASSERT(device); if (device->iface.offload_op != NULL) { return device->iface.offload_op(device, op); } @@ -496,18 +548,22 @@ bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_te // Backend (reg) const char * ggml_backend_reg_name(ggml_backend_reg_t reg) { + GGML_ASSERT(reg); return reg->iface.get_name(reg); } size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg) { + GGML_ASSERT(reg); return reg->iface.get_device_count(reg); } ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(reg); return reg->iface.get_device(reg, index); } void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { + GGML_ASSERT(reg); if (!reg->iface.get_proc_address) { return NULL; } @@ -522,6 +578,7 @@ struct ggml_backend_multi_buffer_context { }; static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; for (size_t i = 0; i < ctx->n_buffers; i++) { ggml_backend_buffer_free(ctx->buffers[i]); @@ -532,6 +589,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) } static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_ASSERT(buffer); ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; for (size_t i = 0; i < ctx->n_buffers; i++) { ggml_backend_buffer_clear(ctx->buffers[i], value); @@ -567,10 +625,12 @@ ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer } bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); return buffer->iface.free_buffer == ggml_backend_multi_buffer_free_buffer; } void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) { + GGML_ASSERT(buffer); GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer)); ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; for (size_t i = 0; i < ctx->n_buffers; i++) { @@ -598,7 +658,7 @@ static bool ggml_is_view_op(enum ggml_op op) { #endif #ifndef GGML_SCHED_MAX_SPLIT_INPUTS -#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC +#define GGML_SCHED_MAX_SPLIT_INPUTS 30 #endif #ifndef GGML_SCHED_MAX_COPIES @@ -849,7 +909,7 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru } // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend -static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { +void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { // reset splits sched->n_splits = 0; sched->n_graph_inputs = 0; @@ -1245,6 +1305,10 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg struct ggml_backend_sched_split * split = &sched->splits[i]; split->graph = ggml_graph_view(graph, split->i_start, split->i_end); + // Optimize this split of the graph. This needs to happen before we make graph_copy, + // so they are in sync. + ggml_backend_graph_optimize(sched->backends[split->backend_id], &split->graph); + // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split for (int j = 0; j < split->n_inputs; j++) { assert(graph_copy->size > (graph_copy->n_nodes + 1)); @@ -1350,17 +1414,22 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { } static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) { + GGML_ASSERT(sched); struct ggml_backend_sched_split * splits = sched->splits; - for (int i = 0; i < sched->n_splits; i++) { - struct ggml_backend_sched_split * split = &splits[i]; + ggml_tensor * prev_ids_tensor = nullptr; + std::vector ids; + std::vector used_ids; + + for (int split_id = 0; split_id < sched->n_splits; split_id++) { + struct ggml_backend_sched_split * split = &splits[split_id]; int split_backend_id = split->backend_id; ggml_backend_t split_backend = sched->backends[split_backend_id]; // copy the input tensors to the split backend - for (int j = 0; j < split->n_inputs; j++) { - ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]); - struct ggml_tensor * input = split->inputs[j]; + for (int input_id = 0; input_id < split->n_inputs; input_id++) { + ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]); + struct ggml_tensor * input = split->inputs[input_id]; struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy); if (input->flags & GGML_TENSOR_FLAG_INPUT) { @@ -1378,16 +1447,104 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } else { ggml_backend_synchronize(split_backend); } - // try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies and events - // TODO: add public function to facilitate this, since applications do not have direct access to the backend interface - if (!split_backend->iface.cpy_tensor_async || !split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) { + + // when offloading MoE weights, we can reduce the amount of data copied by copying only the experts that are used + ggml_tensor * node = split->graph.nodes[0]; + if (split->graph.n_nodes > 0 && + ggml_backend_buffer_get_usage(input->buffer) == GGML_BACKEND_BUFFER_USAGE_WEIGHTS && + ggml_backend_buffer_is_host(input->buffer) && ( + (node->src[0] == input_cpy && node->op == GGML_OP_MUL_MAT_ID) + //|| (node->src[1] == input_cpy && node->op == GGML_OP_ADD_ID) /* GGML_OP_ADD_ID weights are small and not worth splitting */ + )) { + + const int64_t n_expert = node->op == GGML_OP_MUL_MAT_ID ? input->ne[2] : input->ne[1]; + const size_t expert_size = node->op == GGML_OP_MUL_MAT_ID ? input->nb[2] : input->nb[1]; + ggml_backend_synchronize(input_backend); - if (sched->events[split_backend_id][sched->cur_copy] != NULL) { - ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); - } else { - ggml_backend_synchronize(split_backend); + + // get the ids + ggml_tensor * ids_tensor = node->src[2]; + ggml_backend_t ids_backend = split_backend; + + // if the ids tensor is also an input of the split, it may not have been copied yet to the split backend + // in that case, we use the original ids tensor + for (int i = input_id + 1; i < split->n_inputs; i++) { + if (ids_tensor == tensor_copy(split->inputs[i], split_backend_id, sched->cur_copy)) { + ids_tensor = split->inputs[i]; + ids_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[i]); + break; + } + } + + if (ids_tensor != prev_ids_tensor) { + ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t)); + ggml_backend_tensor_get_async(ids_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor)); + ggml_backend_synchronize(ids_backend); + + // find the used experts + used_ids.clear(); + used_ids.resize(ggml_bitset_size(n_expert)); + for (int64_t i1 = 0; i1 < ids_tensor->ne[1]; i1++) { + for (int64_t i0 = 0; i0 < ids_tensor->ne[0]; i0++) { + int32_t id = ids[i1 * ids_tensor->nb[1]/sizeof(int32_t) + i0 * ids_tensor->nb[0]/sizeof(int32_t)]; + GGML_ASSERT(id >= 0 && id < n_expert); + ggml_bitset_set(used_ids.data(), id); + } + } + + prev_ids_tensor = ids_tensor; + } + + // group consecutive experts and copy them together + auto copy_experts = [&](int32_t first_id, int32_t last_id) { + const size_t expert_offset = first_id * expert_size; + const size_t expert_size_copy = (last_id - first_id + 1) * expert_size; + const size_t padding = std::min(expert_size, 512); + const size_t padding_end = last_id < n_expert - 1 ? padding : 0; + + ggml_backend_tensor_set_async(split_backend, + input_cpy, + (const uint8_t *)input->data + expert_offset, expert_offset, + // copy a bit extra at the to ensure there are no NaNs in the padding of the last expert + // this is necessary for MMQ in the CUDA backend + expert_size_copy + padding_end); + }; + + int id = 0; + while (!ggml_bitset_get(used_ids.data(), id)) { + id++; + } + int32_t first_id = id; + int32_t last_id = first_id; + + for (++id; id < n_expert; ++id) { + if (!ggml_bitset_get(used_ids.data(), id)) { + continue; + } + + if (id == last_id + 1) { + last_id = id; + continue; + } + + copy_experts(first_id, last_id); + + first_id = id; + last_id = id; + } + copy_experts(first_id, last_id); + } else { + // try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies and events + // TODO: add public function to facilitate this, since applications do not have direct access to the backend interface + if (!split_backend->iface.cpy_tensor_async || !split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) { + ggml_backend_synchronize(input_backend); + if (sched->events[split_backend_id][sched->cur_copy] != NULL) { + ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); + } else { + ggml_backend_synchronize(split_backend); + } + ggml_backend_tensor_copy(input, input_cpy); } - ggml_backend_tensor_copy(input, input_cpy); } } } @@ -1526,6 +1683,7 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { } void ggml_backend_sched_reset(ggml_backend_sched_t sched) { + GGML_ASSERT(sched); // reset state for the next run if (!sched->is_reset) { ggml_hash_set_reset(&sched->hash_set); @@ -1537,8 +1695,11 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) { } bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) { + GGML_ASSERT(sched); GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs); + ggml_backend_sched_reset(sched); + ggml_backend_sched_synchronize(sched); ggml_backend_sched_split_graph(sched, measure_graph); @@ -1553,6 +1714,7 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * } bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + GGML_ASSERT(sched); GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs); GGML_ASSERT(!sched->is_alloc); @@ -1577,6 +1739,7 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st } enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + GGML_ASSERT(sched); if (!sched->is_reset && !sched->is_alloc) { ggml_backend_sched_reset(sched); } @@ -1591,6 +1754,7 @@ enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sch } void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) { + GGML_ASSERT(sched); for (int i = 0; i < sched->n_backends; i++) { ggml_backend_synchronize(sched->backends[i]); } @@ -1603,28 +1767,34 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) { } void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) { + GGML_ASSERT(sched); sched->callback_eval = callback; sched->callback_eval_user_data = user_data; } int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) { + GGML_ASSERT(sched); return sched->n_splits; } int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) { + GGML_ASSERT(sched); return sched->n_copies; } int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) { + GGML_ASSERT(sched); return sched->n_backends; } ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) { + GGML_ASSERT(sched); GGML_ASSERT(i >= 0 && i < sched->n_backends); return sched->backends[i]; } size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { + GGML_ASSERT(sched); int backend_index = ggml_backend_sched_backend_id(sched, backend); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); @@ -1632,6 +1802,7 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe } void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { + GGML_ASSERT(sched); int backend_index = ggml_backend_sched_backend_id(sched, backend); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); tensor_backend_id(node) = backend_index; @@ -1640,6 +1811,7 @@ void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct gg } ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) { + GGML_ASSERT(sched); int backend_index = tensor_backend_id(node); if (backend_index == -1) { return NULL; @@ -1650,6 +1822,7 @@ ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, // utils enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) { + GGML_ASSERT(tensor); GGML_ASSERT(tensor->buffer == NULL); GGML_ASSERT(tensor->view_src != NULL); GGML_ASSERT(tensor->view_src->buffer != NULL); @@ -1661,6 +1834,7 @@ enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) { } enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) { + GGML_ASSERT(tensor); GGML_ASSERT(tensor->buffer == NULL); GGML_ASSERT(tensor->data == NULL); GGML_ASSERT(tensor->view_src == NULL); @@ -1734,6 +1908,7 @@ static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_ } struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) { + GGML_ASSERT(graph); struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size); struct ggml_tensor ** node_copies = (ggml_tensor **) calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT bool * node_init = (bool *) calloc(hash_set.size, sizeof(node_init[0])); @@ -1878,6 +2053,7 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t // CPU backend - buffer static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); uintptr_t data = (uintptr_t)buffer->context; // align the buffer @@ -1889,28 +2065,33 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { } static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); ggml_aligned_free(buffer->context, buffer->size); } static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + GGML_ASSERT(tensor); memset((char *)tensor->data + offset, value, size); GGML_UNUSED(buffer); } static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor); memcpy((char *)tensor->data + offset, data, size); GGML_UNUSED(buffer); } static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor); memcpy(data, (const char *)tensor->data + offset, size); GGML_UNUSED(buffer); } static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + GGML_ASSERT(src); if (ggml_backend_buffer_is_host(src->buffer)) { memcpy(dst->data, src->data, ggml_nbytes(src)); return true; @@ -1921,6 +2102,7 @@ static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con } static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_ASSERT(buffer); memset(buffer->context, value, buffer->size); } diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index aeac2e57449..5b888cdd8cd 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -270,6 +270,7 @@ static struct ggml_backend_i blas_backend_i = { /* .graph_compute = */ ggml_backend_blas_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, }; static ggml_guid_t ggml_backend_blas_guid(void) { diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 259a2928b1f..434023dd22a 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -70,6 +70,8 @@ #include #include #include +#include +#include #include #include @@ -587,9 +589,16 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // the position of elements in the array means which dirction to padding, // each position means: [dim0.front, dim0.behind, dim1.front, dim1.behind, // dim2.front, dim2.behind, dim3.front, dim3.behind] - int64_t paddings[] = { - 0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1], - 0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]}; + const int32_t lp0 = ggml_get_op_params_i32(dst, 0); + const int32_t rp0 = ggml_get_op_params_i32(dst, 1); + const int32_t lp1 = ggml_get_op_params_i32(dst, 2); + const int32_t rp1 = ggml_get_op_params_i32(dst, 3); + const int32_t lp2 = ggml_get_op_params_i32(dst, 4); + const int32_t rp2 = ggml_get_op_params_i32(dst, 5); + const int32_t lp3 = ggml_get_op_params_i32(dst, 6); + const int32_t rp3 = ggml_get_op_params_i32(dst, 7); + + int64_t paddings[] = {lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3}; aclnn_pad(ctx, acl_src, acl_dst, paddings); ggml_cann_release_resources(ctx, acl_src, acl_dst); } @@ -867,6 +876,86 @@ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer, return acl_tensor; } +/** + * @brief Fills a tensor with a scalar value. + * + * This function fills the destination tensor `acl_dst` with the scalar value + * `scalar`. + * + * @param ctx The context for the CANN backend operations. + * @param scalar The scalar value used to fill the tensor. + * @param acl_dst The destination tensor to be filled with the scalar value. + */ +static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, + aclTensor* acl_dst) { + auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar); + ggml_cann_release_resources(ctx, acl_scalar); +} + +/** + * @brief Get or expand a cached float32 tensor filled with a scalar value. + * + * This function manages cached device memory for float32 tensors. If the current + * cache size is insufficient for the requested tensor shape, the old memory will + * be released and new memory will be allocated. The allocated buffer is then + * initialized either with zeros (when @p value == 0.0f) or with the given scalar + * value using CANN operations. Finally, an aclTensor object is created from the + * cached memory and returned. + * + * @param ctx The CANN backend context that manages device memory. + * @param buffer A pointer to the cached device buffer (will be allocated + * or reallocated if necessary). + * @param cache_element The current number of cached elements. This will be + * updated when the cache is expanded. + * @param ne The tensor shape array (number of elements in each dimension). + * @param nb The stride size for each dimension. + * @param dims The number of tensor dimensions. + * @param value The scalar value used to fill the tensor (supports zero + * initialization via memset or arbitrary values via fill_scalar). + * @return An aclTensor pointer created from the cached buffer. + */ +static aclTensor* get_f32_cache_acl_tensor( + ggml_backend_cann_context& ctx, + void** buffer, + int64_t &cache_element, + int64_t* ne, + size_t* nb, + int64_t dims, + float value) { + // Calculate total number of elements + int64_t n_element = 1; + for (int i = 0; i < dims; i++) { + n_element *= ne[i]; + } + size_t size = n_element * sizeof(float); + + // Allocate or expand cache if needed + if (cache_element < n_element) { + if (*buffer != nullptr) { + aclrtFree(*buffer); + *buffer = nullptr; + } + + ACL_CHECK(aclrtMalloc(buffer, size, ACL_MEM_MALLOC_HUGE_FIRST)); + cache_element = n_element; + + // Initialize cache + if (value == 0.0f) { + ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream())); + } else { + int64_t pool_ne[1] = { n_element }; + size_t pool_nb[1] = { sizeof(float) }; + aclTensor* acl_value = ggml_cann_create_tensor( + *buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1); + aclnn_fill_scalar(ctx, 1, acl_value); + ggml_cann_release_resources(ctx, acl_value); + } + } + + return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims); +} + void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src = dst->src[0]; @@ -875,20 +964,40 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); - size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src); - ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes); - - aclTensor* acl_gamma = aclnn_values( - ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne, 1, - ggml_cann_type_mapping(src->type), ggml_element_size(src)); - - size_t zero_tensor_n_bytes = - src->ne[1] * src->ne[2] * src->ne[3] * ggml_element_size(src); - ggml_cann_pool_alloc zero_tensor_allocator(ctx.pool(), zero_tensor_n_bytes); - aclTensor* acl_rstd = - aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes, - src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type), - ggml_element_size(src)); + + // build gamma, one... + size_t acl_gamma_nb[GGML_MAX_DIMS]; + acl_gamma_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1]; + } + aclTensor* acl_gamma = get_f32_cache_acl_tensor( + ctx, + &ctx.rms_norm_one_tensor_cache.cache, + ctx.rms_norm_one_tensor_cache.size, + src->ne, + acl_gamma_nb, + 1, // dims + 1.0f // value + ); + + // build rstd, zero... + int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]}; + size_t acl_rstd_nb[GGML_MAX_DIMS - 1]; + acl_rstd_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS - 1; i++) { + acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1]; + } + aclTensor* acl_rstd = get_f32_cache_acl_tensor( + ctx, + &ctx.rms_norm_zero_tensor_cache.cache, + ctx.rms_norm_zero_tensor_cache.size, + acl_rstd_ne, + acl_rstd_nb, + GGML_MAX_DIMS - 1, + 0.0f // value + ); + GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd); ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd); } @@ -903,14 +1012,13 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, const int n_past = ((int32_t*)dst->op_params)[0]; - size_t one_tensor_n_bytes = src->ne[0] * src->ne[1] * src->ne[2] * - src->ne[3] * ggml_element_size(src); - ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes); + ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), ggml_nbytes(src)); + void* buffer = one_tensor_allocator.get(); + + aclTensor* mask_tensor = ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(src->type), + ggml_type_size(src->type), src->ne, src->nb, GGML_MAX_DIMS); - aclTensor* mask_tensor = - aclnn_values(ctx, one_tensor_allocator.get(), one_tensor_n_bytes, - src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type), - ggml_element_size(src), value); + aclnn_fill_scalar(ctx, value, mask_tensor); aclScalar* alpha = nullptr; float alphaValue = 1.0f; @@ -1159,12 +1267,20 @@ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) { void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) { - GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst); + if(acl_dst == nullptr) { + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCos, acl_src); + } else { + GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst); + } } void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) { - GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst); + if(acl_dst == nullptr) { + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSin, acl_src); + } else { + GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst); + } } void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, @@ -1277,23 +1393,6 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, tmp_permute_tensor, tmp_mul_tensor, acl_dst); } -/** - * @brief Fills a tensor with a scalar value. - * - * This function fills the destination tensor `acl_dst` with the scalar value - * `scalar`. - * - * @param ctx The context for the CANN backend operations. - * @param scalar The scalar value used to fill the tensor. - * @param acl_dst The destination tensor to be filled with the scalar value. - */ -static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, - aclTensor* acl_dst) { - auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar); - ggml_cann_release_resources(ctx, acl_scalar); -} - /** * @brief Raises each element of a tensor to the power of the corresponding * element in another tensor. @@ -1334,21 +1433,25 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, * @param start Starting exponent offset. * @param stop Stopping exponent offset (exclusive). * @param step Step size for the exponent increment. + * @param dtype Data type for slope tensor. */ static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer, - float m, int64_t size, float start, float stop, float step){ + float m, int64_t size, float start, float stop, float step, ggml_type dtype){ + aclDataType acl_type = ggml_cann_type_mapping(dtype); + size_t type_size = ggml_type_size(dtype); + int64_t ne[] = {size}; - size_t nb[] = {sizeof(float)}; + size_t nb[] = {type_size}; - ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float)); + ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * type_size); void* arange_buffer = arange_allocator.get(); aclTensor* arange_tensor = ggml_cann_create_tensor( - arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1); + arange_buffer, acl_type, type_size, ne, nb, 1); aclnn_arange(ctx, arange_tensor, start, stop, step, size); aclTensor* slope_tensor = ggml_cann_create_tensor( - slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1); + slope_buffer, acl_type, type_size, ne, nb, 1); aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT); @@ -1379,10 +1482,11 @@ static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_bu * @param n_head Total number of attention heads. * @param slope_buffer Pointer to the output buffer (float array) for storing slopes. * @param max_bias Maximum bias value for slope computation. + * @param dtype Data type for slope tensor. * */ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, - void* slope_buffer, float max_bias) { + void* slope_buffer, float max_bias, ggml_type dtype) { const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); float m0 = powf(2.0f, -(max_bias) / n_head_log2); @@ -1399,7 +1503,7 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, float step = 1; float count = n_head_log2; // end needs to be +1 because aclnn uses a left-closed, right-open interval. - aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step); + aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step, dtype); if (n_head_log2 < n_head) { // arange2 start = 2 * (n_head_log2 - n_head_log2) + 1; @@ -1408,7 +1512,7 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, count = n_head - n_head_log2; aclnn_get_slope_inner( ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), - m1, count, start, end + 1, step); + m1, count, start, end + 1, step, dtype); } } @@ -1445,7 +1549,7 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask, ggml_cann_pool_alloc bias_allocator( ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst)); bias_buffer = bias_allocator.get(); - aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias); + aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias, GGML_TYPE_F32); } // broadcast for mask, slop and dst; @@ -1671,10 +1775,10 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { case GGML_TYPE_F16: { aclTensor* acl_src0 = ggml_cann_create_tensor(src0); ggml_cann_pool_alloc src_buffer_allocator( - ctx.pool(), ggml_nelements(src0) * sizeof(float_t)); + ctx.pool(), ggml_nelements(src0) * sizeof(float)); void* src_trans_buffer = src_buffer_allocator.get(); size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = sizeof(float_t); + src_trans_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; } @@ -1718,14 +1822,14 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // [3,4,5,64] -> [3,4,5,2,32] dequant_ne = weight_ne; - dequant_nb[0] = sizeof(float_t); + dequant_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS + 1; i++) { dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1]; } scale_offset = ggml_nelements(src0) * sizeof(int8_t); ggml_cann_pool_alloc dequant_buffer_allocator( - ctx.pool(), ggml_nelements(src0) * sizeof(float_t)); + ctx.pool(), ggml_nelements(src0) * sizeof(float)); aclTensor* acl_weight_tensor = ggml_cann_create_tensor( src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, @@ -1734,11 +1838,11 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); aclTensor* dequant_tensor = ggml_cann_create_tensor( - dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float_t), + dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor); - dequant_nb[0] = sizeof(float_t); + dequant_nb[0] = sizeof(float); dequant_ne = src0->ne; for (int i = 1; i < GGML_MAX_DIMS; i++) { dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1]; @@ -1859,7 +1963,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, aclTensor* acl_weight_tensor; // Only check env once. - static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("")); + static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); if (weight_to_nz && is_matmul_weight(weight)) { int64_t acl_stride[2] = {1, transpose_ne[1]}; @@ -2140,63 +2244,190 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, ggml_cann_release_resources(ctx, acl_index, acl_value); } +/** + * @brief Initializes and caches sine/cosine positional encoding values + * (used in RoPE, Rotary Position Embedding) for attention layers. + * + * This function computes and caches the sin/cos values of + * θ = position * theta_scale for RoPE encoding. The cache is shared + * across attention layers, and only the first attention layer will + * trigger initialization. The cache includes repeated sin/cos values + * with different repeat methods depending on the @param is_neox flag. + * + * Steps performed by this function: + * 1. Identify whether the target tensor belongs to Q/K in attention + * and restrict computation to the first layer only. + * 2. Initialize the theta scale array (arange → power → freq scaling). + * 3. Allocate sin/cos caches if the max prompt length increases. + * 4. Compute θ = position * theta_scale. + * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor. + * 6. Expand sin/cos values by repeat or repeat_interleave depending + * on whether @param is_neox is enabled. + * + * @param ctx The CANN backend context, holding memory pool, + * stream, and persistent buffers for rope init/cache. + * @param dst The destination ggml_tensor whose computation + * depends on the RoPE values (usually Qcur/Kcur). + * @param theta_scale Scalar exponent base for computing theta scale values. + * @param freq_scale Frequency scaling factor, applied to theta scale. + * @param attn_factor Attention scaling factor, applied to sin/cos. + * @param is_neox Whether to use Neox-style repeat strategy + * (dim expansion vs repeat_interleave). + */ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, - aclTensor* acl_cos_repeat_tensor, - aclTensor* acl_sin_repeat_tensor, + float* corr_dims, float ext_factor, float theta_scale, float freq_scale, float attn_factor, bool is_neox) { - // int sin/cos cache, cache has different repeat method depond on - // @param.is_neox - ggml_tensor* src0 = dst->src[0]; // input ggml_tensor* src1 = dst->src[1]; // position ggml_tensor* src2 = dst->src[2]; // freq_factors - GGML_TENSOR_BINARY_OP_LOCALS + if(src2 == nullptr && ctx.rope_cache.cached + && ctx.rope_cache.ext_factor == ext_factor + && ctx.rope_cache.theta_scale == theta_scale + && ctx.rope_cache.freq_scale == freq_scale + && ctx.rope_cache.attn_factor == attn_factor + && ctx.rope_cache.is_neox == is_neox) { + // use cache. + return; + } - // theta_scale arange, [0,1,...,ne00/2 - 1] - int64_t theta_scale_length = ne00 / 2; - ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(), - theta_scale_length * sizeof(float_t)); - void* theta_scale_buffer = theta_scale_allocator.get(); + int64_t theta_scale_length = src0->ne[0] / 2; int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1}; - size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t), - theta_scale_length * sizeof(float_t)}; + size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float), + theta_scale_length * sizeof(float)}; - aclTensor* acl_theta_scale_tensor = - ggml_cann_create_tensor(theta_scale_buffer, ACL_FLOAT, sizeof(float_t), - theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); - float start = 0; - float step = 1; - float stop = ne00 / 2; - float n_elements = ne00 / 2; - aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements); - - // power - aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor, - acl_theta_scale_tensor); - - // freq_scale - if (freq_scale != 1) { - aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true); + GGML_ASSERT(src1->type == GGML_TYPE_I32); + int64_t position_length = src1->ne[0]; + int64_t position_ne[] = {1, 1, position_length, 1}; + size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), + sizeof(int32_t) * position_length}; + + int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1}; + size_t theta_nb[GGML_MAX_DIMS]; + theta_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1]; } + // theta_scale arange, [0,1,...,ne00/2 - 1] + aclTensor* acl_theta_scale_tensor = nullptr; + // cache theta scale + if (ctx.rope_cache.theta_scale_length != theta_scale_length || + // theta_scale and freq_scale should not change during the current token inference process, + // so we can directly use == here instead of comparing the absolute difference. + ctx.rope_cache.theta_scale != theta_scale || + ctx.rope_cache.freq_scale != freq_scale) { + + ctx.rope_cache.theta_scale_length = theta_scale_length; + + if (ctx.rope_cache.theta_scale_cache != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache)); + } + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + + acl_theta_scale_tensor = + ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + + float start = 0; + float step = 1; + float stop = theta_scale_length; + float n_elements = theta_scale_length; + aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements); + + ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool()); + aclTensor* acl_yarn_ramp_tensor = nullptr; + if (ext_factor != 0) { + // -rope_yarn_ramp + // const float y = (i0 / 2 - low) / MAX(0.001f, high - low); + // return MIN(1, MAX(0, y)) - 1; + yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); + void* yarn_ramp_buffer = yarn_ramp_allocator.get(); + acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + float zero_value = 0, one_value = 1; + float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); + aclScalar* low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT); + aclScalar* zero = aclCreateScalar(&zero_value, aclDataType::ACL_FLOAT); + aclScalar* one = aclCreateScalar(&one_value, aclDataType::ACL_FLOAT); + aclScalar* denom_safe = aclCreateScalar(&denom_safe_value, aclDataType::ACL_FLOAT); + aclScalar* ext_factor_sc = aclCreateScalar(&ext_factor, aclDataType::ACL_FLOAT); + + GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor, low, one, acl_yarn_ramp_tensor); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor, denom_safe); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor, zero, zero); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor, one); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor, one, one); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, ext_factor_sc); + + // theta_interp = freq_scale * theta_extrap; + // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix; + // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix; + // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix); + // + // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse + // cache freq_scale + (freq_scale - 1) * ramp_mix + float freq_scale_1 = freq_scale - 1; + aclScalar* freq_scale_sc = aclCreateScalar(&freq_scale, aclDataType::ACL_FLOAT); + aclScalar* freq_scale_1_sc = aclCreateScalar(&freq_scale_1, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, freq_scale_1_sc); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor, freq_scale_sc, one); + + ggml_cann_release_resources(ctx, low, zero, one, denom_safe, ext_factor_sc, freq_scale_sc, freq_scale_1_sc); + } + + // power + aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor, + acl_theta_scale_tensor); + + if (ext_factor != 0) { + aclnn_mul(ctx, acl_theta_scale_tensor, acl_yarn_ramp_tensor); + } else if (freq_scale != 1) { + aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true); + } + + ggml_cann_release_resources(ctx, acl_yarn_ramp_tensor, acl_theta_scale); + } else { + // use cache + acl_theta_scale_tensor = + ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + } + + ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool()); // freq_factors if (src2) { + freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float)); + void* freq_fac_res_ptr = freq_fac_res_allocator.get(); aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor( src2->data, ggml_cann_type_mapping(src2->type), ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); - aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor); - ggml_cann_release_resources(ctx, acl_freq_factors_tensor); + aclTensor* acl_freq_fac_res_tensor = ggml_cann_create_tensor( + freq_fac_res_ptr, ACL_FLOAT, sizeof(float), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, acl_freq_fac_res_tensor); + std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor); + ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor); + } + + // init sin_repeat && cos_repeat, only to accelerate first layer on each device + if (position_length > ctx.rope_cache.position_length) { + ctx.rope_cache.position_length = position_length; + if (ctx.rope_cache.sin_cache != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache)); + } + if (ctx.rope_cache.cos_cache != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache)); + } + int64_t repeat_theta_length = theta_scale_length * position_length * 2; + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); } // position - GGML_ASSERT(src1->type == GGML_TYPE_I32); - int64_t position_length = src1->ne[0]; - int64_t position_ne[] = {1, 1, position_length, 1}; - size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), - sizeof(int32_t) * position_length}; aclTensor* acl_position_tensor = ggml_cann_create_tensor( src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS); @@ -2204,43 +2435,55 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, // power * position int64_t theta_length = theta_scale_length * position_length; ggml_cann_pool_alloc theta_allocator(ctx.pool(), - theta_length * sizeof(float_t)); + theta_length * sizeof(float)); void* theta_buffer = theta_allocator.get(); - int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1}; - size_t theta_nb[GGML_MAX_DIMS]; - theta_nb[0] = sizeof(float_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1]; - } + aclTensor* acl_theta_tensor = - ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t), + ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS); aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor, - acl_theta_tensor); + acl_theta_tensor); // sin/cos ggml_cann_pool_alloc sin_allocator(ctx.pool(), - theta_length * sizeof(float_t)); + theta_length * sizeof(float)); void* sin_buffer = sin_allocator.get(); aclTensor* acl_sin_tensor = ggml_cann_create_tensor( - sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, + sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor); ggml_cann_pool_alloc cos_allocator(ctx.pool(), - theta_length * sizeof(float_t)); + theta_length * sizeof(float)); void* cos_buffer = cos_allocator.get(); aclTensor* acl_cos_tensor = ggml_cann_create_tensor( - cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, + cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor); + if (ext_factor != 0) { + attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + // attn_factor if (attn_factor != 1) { aclnn_muls(ctx, acl_sin_tensor, attn_factor, nullptr, true); aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true); } + int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1}; + size_t sin_reshape_nb[GGML_MAX_DIMS]; + sin_reshape_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; + } + aclTensor* acl_sin_repeat_tensor = + ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float), + sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); + aclTensor* acl_cos_repeat_tensor = + ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), + sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); + // repeat if (is_neox) { int64_t repeatsArray[] = {1, 1, 1, 2}; @@ -2256,9 +2499,17 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, num_repeats, output_size); } - // release + // Other layers use cache except first layer. + ctx.rope_cache.cached = true; + ctx.rope_cache.ext_factor = ext_factor; + ctx.rope_cache.theta_scale = theta_scale; + ctx.rope_cache.freq_scale = freq_scale; + ctx.rope_cache.attn_factor = attn_factor; + ctx.rope_cache.is_neox = is_neox; + ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor, - acl_theta_tensor, acl_sin_tensor, acl_cos_tensor, acl_theta_scale); + acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor, + acl_cos_repeat_tensor); } #ifdef __cplusplus @@ -2277,8 +2528,6 @@ aclnnStatus aclnnRotaryPositionEmbedding(void* workspace, #endif void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - // TODO: use ascendc - // Only test with LLAMA model. ggml_tensor* src0 = dst->src[0]; // input // param @@ -2301,8 +2550,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // TODO: n_dims <= ne0 GGML_ASSERT(n_dims == ne0); GGML_ASSERT(n_dims % 2 == 0); - // TODO: ext_factor != 0 - GGML_ASSERT(ext_factor == 0); const float theta_scale = powf(freq_base, -2.0f / n_dims); @@ -2312,28 +2559,22 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; - // init cos/sin cache - ggml_cann_pool_alloc sin_allocator( - ctx.pool(), ne00 * ne02 * sizeof(float_t)); - ggml_cann_pool_alloc cos_allocator( - ctx.pool(), ne00 * ne02 * sizeof(float_t)); - void* sin_buffer = sin_allocator.get(); - void* cos_buffer = cos_allocator.get(); + // init ctx.rope_cos/rope_sin cache + aclnn_cache_init(ctx, dst, corr_dims, ext_factor, + theta_scale, freq_scale, attn_factor, is_neox); int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; size_t sin_reshape_nb[GGML_MAX_DIMS]; - sin_reshape_nb[0] = sizeof(float_t); + sin_reshape_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; } aclTensor* acl_sin_reshape_tensor = - ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t), + ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclTensor* acl_cos_reshape_tensor = - ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t), + ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); - aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor, - theta_scale, freq_scale, attn_factor, is_neox); aclTensor* acl_src = ggml_cann_create_tensor(src0); aclTensor* acl_dst = ggml_cann_create_tensor(dst); @@ -2347,7 +2588,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { void* minus_one_scale_buffer = nullptr; ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0)); ggml_cann_pool_alloc minus_one_scale_allocator( - ctx.pool(), sizeof(float_t) * src0->ne[0]); + ctx.pool(), sizeof(float) * src0->ne[0]); if (!is_neox) { // roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...] input_roll_buffer = roll_allocator.get(); @@ -2377,13 +2618,13 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1}; size_t minus_one_nb[GGML_MAX_DIMS]; - minus_one_nb[0] = sizeof(float_t); + minus_one_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1]; } acl_minus_one_tensor = aclnn_values( - ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0], - minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1); + ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0], + minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1); int64_t dim = 3; int64_t* index = new int64_t[src0->ne[0]]; for (int i = 0; i < src0->ne[0]; i++) { @@ -2411,22 +2652,22 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { minus_one_scale_buffer = minus_one_scale_allocator.get(); int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1}; size_t minus_one_nb[GGML_MAX_DIMS]; - minus_one_nb[0] = sizeof(float_t); + minus_one_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1]; } acl_minus_one_tensor = aclnn_values( - ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0], - minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1); + ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0], + minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1); // -1 * first half int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1}; size_t first_half_nb[GGML_MAX_DIMS]; - first_half_nb[0] = sizeof(float_t); + first_half_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1]; } aclTensor* acl_first_half_tensor = ggml_cann_create_tensor( - minus_one_scale_buffer, ACL_FLOAT, sizeof(float_t), first_half_ne, + minus_one_scale_buffer, ACL_FLOAT, sizeof(float), first_half_ne, first_half_nb, GGML_MAX_DIMS); bool inplace = true; float scale = -1; @@ -2466,28 +2707,28 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // TODO: ne0 != n_dims in mode2 } else if (src0->type == GGML_TYPE_F16) { size_t input_fp32_nb[GGML_MAX_DIMS]; - input_fp32_nb[0] = sizeof(float_t); + input_fp32_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1]; } ggml_cann_pool_alloc fp32_allocator1( - ctx.pool(), ggml_nelements(dst) * sizeof(float_t)); + ctx.pool(), ggml_nelements(dst) * sizeof(float)); void* input_fp32_buffer1 = fp32_allocator1.get(); aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor( - input_fp32_buffer1, ACL_FLOAT, sizeof(float_t), dst->ne, + input_fp32_buffer1, ACL_FLOAT, sizeof(float), dst->ne, input_fp32_nb, GGML_MAX_DIMS); ggml_cann_pool_alloc fp32_allocator2( - ctx.pool(), ggml_nelements(dst) * sizeof(float_t)); + ctx.pool(), ggml_nelements(dst) * sizeof(float)); void* input_fp32_buffer2 = fp32_allocator2.get(); aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor( - input_fp32_buffer2, ACL_FLOAT, sizeof(float_t), dst->ne, + input_fp32_buffer2, ACL_FLOAT, sizeof(float), dst->ne, input_fp32_nb, GGML_MAX_DIMS); ggml_cann_pool_alloc fp32_allocator( - ctx.pool(), ggml_nelements(dst) * sizeof(float_t)); + ctx.pool(), ggml_nelements(dst) * sizeof(float)); output_fp32_buffer = fp32_allocator.get(); aclTensor* output_fp32_tensor = ggml_cann_create_tensor( - output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne, + output_fp32_buffer, ACL_FLOAT, sizeof(float), dst->ne, input_fp32_nb, GGML_MAX_DIMS); aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1); aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor, @@ -2584,8 +2825,6 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds aclIntArray *padding = aclCreateIntArray(paddingVal, 1); int64_t dilationVal[] = {1}; aclIntArray *dilation = aclCreateIntArray(dilationVal, 1); - bool transposed = true; - int64_t groups = 1; int8_t cubeMathType = 0; #ifdef ASCEND_310P @@ -2593,7 +2832,7 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds #endif GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input, acl_weight, nullptr, stride, - padding, dilation, transposed, padding, groups, acl_dst, cubeMathType); + padding, dilation, true, padding, 1, acl_dst, cubeMathType); ggml_cann_release_resources(ctx, acl_weight, acl_dst, stride, padding, dilation); } @@ -2702,174 +2941,49 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){ */ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) { //dst [M, K, N, 1] - ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] - ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 + ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] -> [D, M, K, 1] + ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 -> [D, 1, K, 1] ggml_tensor * ids = dst->src[2]; //ids [K, N] - GGML_TENSOR_BINARY_OP_LOCALS - - // copy index from npu to cpu - int64_t n_as = ne02; // A - int64_t n_ids = ids->ne[0]; // K - - std::vector ids_host(ggml_nbytes(ids)); - ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids), - ACL_MEMCPY_DEVICE_TO_HOST); - ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); + GGML_ASSERT(src0->ne[3] == 1); + GGML_ASSERT(src1->ne[3] == 1); + GGML_ASSERT(dst->ne[3] == 1); - char * src0_original = (char *) src0->data; - char * src1_original = (char *) src1->data; - char * dst_original = (char *) dst->data; - size_t ori_src0_nb[4] = {nb00, nb01, nb02, nb03}; + int64_t batch = src1->ne[2]; + GGML_ASSERT(batch == ids->ne[1]); - // src0 is F16, src1 is F32, dst is F32 - ggml_cann_pool_alloc src0_cast_allocator; - if (src0->type == GGML_TYPE_F16) { - src0_cast_allocator.alloc(ctx.pool(), sizeof(float) * ggml_nelements(src0)); - void* src0_cast_buf = src0_cast_allocator.get(); + ggml_cann_pool_alloc export_allocator(ctx.pool(), src0->ne[0] * src0->ne[1] * ids->ne[0] * ggml_element_size(src0)); + void* export_ptr = export_allocator.get(); + for (int64_t i = 0; i < batch; i++) { + aclTensor *select_index = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, i * ids->nb[1]); + aclTensor *export_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3); - size_t cast_nb[GGML_MAX_DIMS]; - cast_nb[0] = sizeof(float_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - cast_nb[i] = cast_nb[i - 1] * src0->ne[i - 1]; + int64_t select_export_ne[] = {src0->ne[0], src0->ne[1], ids->ne[0]}; + size_t select_export_nb[3]; + select_export_nb[0] = src0->nb[0]; + for (int k = 1;k < 3; k++) { + select_export_nb[k] = select_export_nb[k-1] * select_export_ne[k-1]; } - aclTensor* acl_src0_f16 = ggml_cann_create_tensor(src0); - aclTensor* acl_cast = ggml_cann_create_tensor(src0_cast_buf, - ACL_FLOAT, sizeof(float), src0->ne, cast_nb, 4); - GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src0_f16, ACL_FLOAT, acl_cast); - ggml_cann_release_resources(ctx, acl_cast, acl_src0_f16); + aclTensor *select_export = ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), select_export_ne, select_export_nb, 3); + GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, export_weight, 0, select_index, select_export); - src0_original = (char *) src0_cast_buf; - memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb)); - } + int64_t select_transpose_ne[] = {select_export_ne[1], select_export_ne[0], select_export_ne[2]}; + size_t select_transpose_nb[] = {select_export_nb[1], select_export_nb[0], select_export_nb[2]}; + aclTensor *select_export_transpose = ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), select_transpose_ne, select_transpose_nb, 3); -#ifdef ASCEND_310P - ggml_tensor src0_row = *src0; - ggml_tensor src1_row = *src1; - ggml_tensor dst_row = *dst; + int64_t active_tensor_ne[] = {src1->ne[0], 1, src1->ne[1]}; + size_t active_tensor_nb[] = {src1->nb[0], src1->nb[1], src1->nb[1]}; + aclTensor *active_tensor = ggml_cann_create_tensor(src1, active_tensor_ne, active_tensor_nb, 3, ACL_FORMAT_ND, i * src1->nb[2]); - if (src0->type == GGML_TYPE_F16) { - src0_row.type = GGML_TYPE_F32; - } + int64_t dst_ne[] = {dst->ne[0], 1, dst->ne[1]}; + size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[1]}; + aclTensor *acl_dst = ggml_cann_create_tensor(dst, dst_ne,dst_nb, 3, ACL_FORMAT_ND, i * dst->nb[2]); - // src0_row [D, M, 1, 1] weight without permute - src0_row.ne[2] = 1; - src0_row.ne[3] = 1; - src0_row.nb[0] = ori_src0_nb[0]; - src0_row.nb[1] = ori_src0_nb[1]; - src0_row.nb[2] = ori_src0_nb[1]; - src0_row.nb[3] = ori_src0_nb[1]; + GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, active_tensor, select_export_transpose, acl_dst, 2); - // src1_row [D, 1, 1, 1] -> input - src1_row.ne[1] = 1; - src1_row.ne[2] = 1; - src1_row.ne[3] = 1; - src1_row.nb[2] = nb11; - src1_row.nb[3] = nb11; - - // dst_row [M, 1, 1, 1] -> out - dst_row.ne[1] = 1; - dst_row.ne[2] = 1; - dst_row.ne[3] = 1; - dst_row.nb[2] = nb1; - dst_row.nb[3] = nb1; - - //create weight for one row - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - // expert index - int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - GGML_ASSERT(i02 >= 0 && i02 < n_as); - - // If B = 1 (broadcast), always use 0; otherwise, use id. - int64_t i11 = (ne11 == 1 ? 0 : id); - int64_t i12 = iid1; - - int64_t i1 = id; - int64_t i2 = i12; - - void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2]; - void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12; - void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2; - - src0_row.data = src0_tmp_ptr; - src1_row.data = src1_tmp_ptr; - dst_row.data = dst_tmp_ptr; - dst_row.src[0] = &src0_row; - dst_row.src[1] = &src1_row; - - ggml_cann_mul_mat(ctx, &dst_row); - } - } - return; -#endif - - std::vector src0_tensor_vec; - std::vector src1_tensor_vec; - std::vector dst_tensor_vec; - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - // src0_row [M, D] -> weight && permute - int64_t src0_ne[2] = {ne01, ne00}; - size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]}; - // src1_row [D, 1] -> input - int64_t src1_ne[2] = {ne10, 1}; - size_t src1_nb[2] = {nb10, nb11}; - // dst_row [M, 1] -> out - int64_t dst_ne[2] = {ne0, 1}; - size_t dst_nb[2] = {nb0, nb1}; - - // expert index - int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - GGML_ASSERT(i02 >= 0 && i02 < n_as); - - // If B = 1 (broadcast), always use 0; otherwise, use id. - int64_t i11 = (ne11 == 1 ? 0 : id); - int64_t i12 = iid1; - - int64_t i1 = id; - int64_t i2 = i12; - - void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2]; - void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12; - void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2; - - aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr, - ACL_FLOAT, sizeof(float), - src0_ne, src0_nb, 2); - aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr, - ACL_FLOAT, sizeof(float), - src1_ne, src1_nb, 2); - aclTensor* acl_dst = ggml_cann_create_tensor(dst_tmp_ptr, - ACL_FLOAT, sizeof(float), - dst_ne, dst_nb, 2); - - src0_tensor_vec.push_back(acl_src0); - src1_tensor_vec.push_back(acl_src1); - dst_tensor_vec.push_back(acl_dst); - } - } - - size_t GROUP_SIZE = 128; - // GroupedMatmulV3 required tensor_list.size < 128 - for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) { - // split and call GroupedMatmulV3 - size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size()); - std::vector src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end); - std::vector src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end); - std::vector dst_tensor_vec_split(dst_tensor_vec.begin() + i, dst_tensor_vec.begin() + end); - - aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec_split.data(), src0_tensor_vec_split.size()); - aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size()); - aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size()); - - GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV3, src1_tensor_list, src0_tensor_list, - nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list); - - ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list); + ggml_cann_release_resources(ctx, select_index, export_weight, select_export, active_tensor, acl_dst, select_export_transpose); } - return; } /** @@ -3018,11 +3132,38 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ - ggml_tensor* src0 = dst->src[0]; // q, fp32 - ggml_tensor* src1 = dst->src[1]; // k, fp16 - ggml_tensor* src2 = dst->src[2]; // v, fp16 + ggml_tensor* src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont) + ggml_tensor* src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont) + ggml_tensor* src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont) ggml_tensor* src3 = dst->src[3]; // mask, fp16 + // B, N, S, D (uncont) -> B, S, N, D (cont) + int64_t src0_bsnd_ne[GGML_MAX_DIMS]; + memcpy(src0_bsnd_ne, src0->ne, GGML_MAX_DIMS * sizeof(int64_t)); + size_t src0_bsnd_nb[GGML_MAX_DIMS]; + memcpy(src0_bsnd_nb, src0->nb, GGML_MAX_DIMS * sizeof(size_t)); + int64_t src1_bsnd_ne[GGML_MAX_DIMS]; + memcpy(src1_bsnd_ne, src1->ne, GGML_MAX_DIMS * sizeof(int64_t)); + size_t src1_bsnd_nb[GGML_MAX_DIMS]; + memcpy(src1_bsnd_nb, src1->nb, GGML_MAX_DIMS * sizeof(size_t)); + int64_t src2_bsnd_ne[GGML_MAX_DIMS]; + memcpy(src2_bsnd_ne, src2->ne, GGML_MAX_DIMS * sizeof(int64_t)); + size_t src2_bsnd_nb[GGML_MAX_DIMS]; + memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t)); + + auto transpose12 = [](int64_t* ne, size_t* nb) { + int64_t ne_tmp = ne[1]; + size_t nb_tmp = nb[1]; + ne[1] = ne[2]; + nb[1] = nb[2]; + ne[2] = ne_tmp; + nb[2] = nb_tmp; + }; + + transpose12(src0_bsnd_ne, src0_bsnd_nb); + transpose12(src1_bsnd_ne, src1_bsnd_nb); + transpose12(src2_bsnd_ne, src2_bsnd_nb); + float maxBias = 0.0f; float scaleValue = 1.0f; float logitSoftcap = 0.0f; @@ -3044,11 +3185,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ void* src0_f16_buffer = nullptr; if(ggml_cann_type_mapping(src0->type) != faDataType){ - aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0); + aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne, + src0_bsnd_nb, GGML_MAX_DIMS); src0_f16_buffer = src0_f16_allocator.alloc( ggml_nelements(src0) * faElemSize); - int64_t* src0_f16_ne = src0->ne; + int64_t* src0_f16_ne = src0_bsnd_ne; size_t src0_f16_nb[GGML_MAX_DIMS]; src0_f16_nb[0] = sizeof(uint16_t); for(int i = 1; i < GGML_MAX_DIMS; ++i){ @@ -3062,20 +3204,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType); ggml_cann_release_resources(ctx, acl_src0_f32_tensor); }else{ - acl_src0_f16_tensor = ggml_cann_create_tensor(src0); + acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne, + src0_bsnd_nb, GGML_MAX_DIMS); } // Step 2: create the acl tensors for src1 (Key), src2 (Value), // and the direct output from FusedInferAttention - acl_src1_f16_tensor = ggml_cann_create_tensor(src1); - acl_src2_f16_tensor = ggml_cann_create_tensor(src2); + acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, + src1_bsnd_nb, GGML_MAX_DIMS); + acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, + src2_bsnd_nb, GGML_MAX_DIMS); ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); void* out_f16_buffer = out_f16_allocator.alloc( ggml_nelements(dst) * faElemSize); - int64_t* out_f16_ne = src0->ne; + int64_t* out_f16_ne = src0_bsnd_ne; size_t out_f16_nb[GGML_MAX_DIMS]; out_f16_nb[0] = faElemSize; for(int i = 1; i < GGML_MAX_DIMS; ++i){ @@ -3089,88 +3234,81 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ // Step 3: create the PSEShift tensor if needed // this tensor is considered as mask (f16) in the llama.cpp - aclTensor* bcast_pse_tensor = nullptr; - int64_t bcast_pse_ne[GGML_MAX_DIMS]; - size_t bcast_pse_nb[GGML_MAX_DIMS]; ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); - void* bcast_pse_buffer = nullptr; - if(src3 != nullptr){ - bcast_pse_buffer = bcast_pse_allocator.alloc( - ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)); - - if(src0->ne[1] > 1){ - // Case 1: broadcast pse for prefill stage with multiple head - aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3); - bcast_pse_ne[0] = src3->ne[0]; - bcast_pse_ne[1] = src3->ne[1]; - bcast_pse_ne[2] = src0->ne[2]; - bcast_pse_ne[3] = src3->ne[3]; + // Construct the truncated pse tensor (common for prefill/decode) + int64_t trunc_pse_ne[GGML_MAX_DIMS] = { + src3->ne[0], // D + src0->ne[1], // S (number of Q tokens) + src3->ne[2], // mask N + src3->ne[3] // B + }; + size_t* trunc_pse_nb = src3->nb; + + aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( + src3->data, ACL_FLOAT16, sizeof(uint16_t), + trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS + ); + int64_t bcast_pse_ne[GGML_MAX_DIMS]; + size_t bcast_pse_nb[GGML_MAX_DIMS]; + bcast_pse_ne[0] = src3->ne[0]; // D + bcast_pse_ne[1] = src0->ne[1]; // S + bcast_pse_ne[2] = src0->ne[2]; // N (num_heads) + bcast_pse_ne[3] = src3->ne[3]; // B + if (maxBias == 0.0f) { + // When maxBias == 0.0f, use nb = 0 reduce once repeat (Qwen2) + // Construct the bcast tensor (simulate repeat on the head dimension using stride=0) bcast_pse_nb[0] = sizeof(uint16_t); - for(int i = 1; i < GGML_MAX_DIMS; ++i){ - bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; - } + bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0]; + bcast_pse_nb[2] = 0; // <---- the head dimension shares the same data + bcast_pse_nb[3] = src3->nb[3]; bcast_pse_tensor = ggml_cann_create_tensor( - bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), - bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); - - int64_t repeats[] = {1, src0->ne[2], 1, 1}; - aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats); - - ggml_cann_release_resources(ctx, acl_mask_f16_tensor); - }else{ - // Case 2: trunc the first row and broadcast pse for decode stage with multiple head - int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]}; - size_t* trunc_pse_nb = src3->nb; - - aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( src3->data, ACL_FLOAT16, sizeof(uint16_t), - trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS); - - bcast_pse_ne[0] = src3->ne[0]; - bcast_pse_ne[1] = src0->ne[1]; - bcast_pse_ne[2] = src0->ne[2]; - bcast_pse_ne[3] = src3->ne[3]; + bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS + ); + ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor); + } else { bcast_pse_nb[0] = sizeof(uint16_t); - for(int i = 1; i < GGML_MAX_DIMS; ++i){ + for (int i = 1; i < GGML_MAX_DIMS; i++) { bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; } + void* bcast_pse_buffer = bcast_pse_allocator.alloc( + ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t) + ); + bcast_pse_tensor = ggml_cann_create_tensor( bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), - bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); + bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS + ); int64_t repeats[] = {1, src0->ne[2], 1, 1}; aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats); - ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor); - } - - // Compute the slope if needed. Derived from ggml_cann_softmax(). - if(maxBias != 0.0f){ // alibi + // Compute the slope if needed. Derived from ggml_cann_softmax(). const int64_t n_heads = src0->ne[2]; - ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float)); + ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t)); void* slope_buffer = slope_allocator.get(); - aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias); + aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias, GGML_TYPE_F16); int64_t slope_ne[] = {1, 1, n_heads, 1}; size_t slope_nb[GGML_MAX_DIMS]; - slope_nb[0] = sizeof(float); + slope_nb[0] = sizeof(uint16_t); for(int i = 1;ine[0]); // 1/sqrt(d) int64_t preTokens = 65535; int64_t nextTokens = 65535; - char layout[5] = {'B', 'N', 'S', 'D', 0}; + char layout[5] = {'B', 'S', 'N', 'D', 0}; int64_t sparseMode = 0; int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2; int64_t blockSize = 0; @@ -3224,32 +3362,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ); // Step 6: post-processing, permute and cast to f32 - - int64_t new_dim[] = {0, 2, 1, 3}; aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); - - if(ggml_cann_type_mapping(dst->type) != faDataType){ - ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool()); - perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); - void* perm_out_f16_buffer = perm_out_f16_allocator.get(); - - int64_t* perm_out_f16_ne = dst->ne; - size_t perm_out_f16_nb[GGML_MAX_DIMS]; - perm_out_f16_nb[0] = faElemSize; - for(int i = 1; i < GGML_MAX_DIMS; ++i){ - perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1]; - } - aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor( - perm_out_f16_buffer, faDataType, faElemSize, - perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS); - aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS); - aclnn_cast(ctx, - acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); - ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor); - }else{ - // only need to permute - aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS); - } + // TODO: when dst is fp16, don't need cast + aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); ggml_cann_release_resources(ctx, acl_src0_f16_tensor, acl_src1_f16_tensor, acl_src2_f16_tensor, diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 9d294f72b67..b707b843593 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -38,6 +38,7 @@ #include #include #include +#include #include "../include/ggml-cann.h" #include "../include/ggml.h" @@ -106,6 +107,7 @@ int32_t ggml_cann_get_device(); std::optional get_env(const std::string& name); bool parse_bool(const std::string& value); +int parse_integer(const std::string& value); /** * @brief Abstract base class for memory pools used by CANN. @@ -350,7 +352,7 @@ struct ggml_graph_node_properties { struct ggml_cann_graph { ~ggml_cann_graph() { if (graph != nullptr) { - aclmdlRIDestroy(graph); + ACL_CHECK(aclmdlRIDestroy(graph)); } } @@ -358,8 +360,105 @@ struct ggml_cann_graph { std::vector ggml_graph_properties; }; + +/** + * @brief LRU cache for managing ggml_cann_graph objects. + * + * This class maintains a list of shared_ptr to ggml_cann_graph objects + * and enforces a maximum capacity. It provides methods to push new graphs, + * move existing graphs to the front (most recently used), and clear the cache. + */ +struct ggml_cann_graph_lru_cache { + size_t capacity; /**< Maximum number of graphs in the cache. */ + + std::list cache_list; /**< List storing cached graphs as raw pointers. */ + + ggml_cann_graph_lru_cache() { + capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); + } + + /** + * @brief Push a new graph to the front of the cache. + * If the cache exceeds capacity, the least recently used graph is deleted. + * @param new_node Pointer to the new ggml_cann_graph to cache. + * Ownership is transferred to the cache (cache will delete it). + */ + void push(ggml_cann_graph* new_node) { + if (cache_list.size() >= capacity) { + ggml_cann_graph* old = cache_list.back(); + cache_list.pop_back(); + delete old; // free the old graph + } + cache_list.push_front(new_node); + } + + /** + * @brief Move an existing graph to the front of the cache. + * @param node Pointer to the ggml_cann_graph to move. + */ + void move_to_front(ggml_cann_graph* node) { + cache_list.remove(node); + cache_list.push_front(node); + } + + /** + * @brief Clear all graphs from the cache (also frees memory). + */ + void clear() { + for (auto ptr : cache_list) { + delete ptr; + } + cache_list.clear(); + } + + /** + * @brief Destructor that clears the cache and frees all cached graphs. + */ + ~ggml_cann_graph_lru_cache() { + clear(); + } +}; #endif // USE_ACL_GRAPH +struct ggml_cann_rope_cache { + ~ggml_cann_rope_cache() { + if(theta_scale_cache != nullptr) { + ACL_CHECK(aclrtFree(theta_scale_cache)); + } + if(sin_cache != nullptr) { + ACL_CHECK(aclrtFree(sin_cache)); + } + if(cos_cache != nullptr) { + ACL_CHECK(aclrtFree(cos_cache)); + } + } + + void* theta_scale_cache = nullptr; + int64_t theta_scale_length = 0; + // sin/cos cache, used only to accelerate first layer on each device + void* sin_cache = nullptr; + void* cos_cache = nullptr; + int64_t position_length = 0; + // Properties to check before reusing the sincos cache + bool cached = false; + float ext_factor = 0.0f; + float theta_scale = 0.0f; + float freq_scale = 0.0f; + float attn_factor = 0.0f; + bool is_neox = false; +}; + +struct ggml_cann_tensor_cache { + ~ggml_cann_tensor_cache() { + if(cache != nullptr) { + ACL_CHECK(aclrtFree(cache)); + } + } + + void* cache = nullptr; + int64_t size = 0; +}; + /** * @brief Context for managing CANN backend operations. */ @@ -370,11 +469,16 @@ struct ggml_backend_cann_context { aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */ #ifdef USE_ACL_GRAPH /// Cached CANN ACL graph used for executing the current ggml computation graph. - std::unique_ptr cann_graph; + ggml_cann_graph_lru_cache graph_lru_cache; + bool acl_graph_mode = true; #endif cann_task_queue task_queue; bool async_mode; - bool support_set_rows; + // Rope Cache + ggml_cann_rope_cache rope_cache; + // Constant Pool + ggml_cann_tensor_cache rms_norm_one_tensor_cache; + ggml_cann_tensor_cache rms_norm_zero_tensor_cache; aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */ @@ -390,14 +494,13 @@ struct ggml_backend_cann_context { async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or("")); GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF"); - - support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or("")); - GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF"); - - if (!support_set_rows) { - GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. " - "Falling back to eager mode.\n", __func__); - } +#ifdef USE_ACL_GRAPH + acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on")); + GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", + __func__, device, + acl_graph_mode ? "GRAPH" : "EAGER", + acl_graph_mode ? "acl graph enabled" : "acl graph disabled"); +#endif } /** @@ -423,7 +526,10 @@ struct ggml_backend_cann_context { */ aclrtStream stream(int stream) { if (streams[stream] == nullptr) { - ggml_cann_set_device(device); + // If the device is not set here, destroying the stream later may cause a mismatch + // between the thread contexts where the stream was created and destroyed. + // However, I printed the device_id, thread_id, and stream, and they are all consistent. + ACL_CHECK(aclrtSetDevice(device)); ACL_CHECK(aclrtCreateStream(&streams[stream])); } return streams[stream]; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index cb8af42ebf9..b51b554e752 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -75,13 +75,12 @@ * @param device The device ID to set. */ void ggml_cann_set_device(const int32_t device) { - // TODO: uncomment these lines after empty context has fixed. - // int current_device; - // ACL_CHECK(aclrtGetDevice(¤t_device)); + int current_device = -1; + aclrtGetDevice(¤t_device); - // if (device == current_device) { - // return; - // } + if (device == current_device) { + return; + } ACL_CHECK(aclrtSetDevice(device)); } @@ -116,6 +115,24 @@ bool parse_bool(const std::string& value) { return valid_values.find(value) != valid_values.end(); } +/** + * @brief Parse a string as an integer, returning 0 if invalid. + * + * This function attempts to convert the input string `value` to an `int`. + * If the string is not a valid integer or is out of the `int` range, + * it returns 0. + * + * @param value The string to parse. + * @return The parsed integer, or 0 if conversion fails. + */ +int parse_integer(const std::string& value) { + try { + return std::stoi(value); + } catch (...) { + return 0; + } +} + /** * @brief Initialize the CANN device information. * @@ -1116,30 +1133,65 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor( return GGML_STATUS_SUCCESS; } -// ND to NZ Workspace Cache Management. Thread-safety: Not guaranteed -namespace { - void* g_nz_workspace = nullptr; - size_t g_nz_workspace_allocated = 0; +/** + * @brief Workspace for caching NZ buffers per device. + * + * This struct manages a device buffer used in NZ computations. It supports + * allocation, reallocation, and clearing of cached memory. The struct is + * designed to be used with a global array, one per device. + */ +struct ggml_cann_nz_workspace { + void* ptr; // Pointer to allocated device buffer + size_t allocated; // Size of currently allocated buffer in bytes - void release_nz_workspace() { - if (g_nz_workspace) { - aclrtFree(g_nz_workspace); - g_nz_workspace = nullptr; - g_nz_workspace_allocated = 0; + /** + * @brief Constructor. Initializes the workspace with no allocated memory. + */ + ggml_cann_nz_workspace() : ptr(nullptr), allocated(0) {} + + /** + * @brief Free cached memory and reset the workspace. + * + * If a buffer has been allocated, this function releases it using + * aclrtFree and resets internal state. + */ + void clear() { + if (ptr) { + ACL_CHECK(aclrtFree(ptr)); + ptr = nullptr; + allocated = 0; } } - void relloc_nz_workspace(size_t new_size) { - if (new_size > g_nz_workspace_allocated) { - if (g_nz_workspace) { - aclrtFree(g_nz_workspace); - g_nz_workspace = nullptr; + /** + * @brief Allocate or reallocate the workspace buffer. + * + * If the requested size is larger than the currently allocated size, + * the old buffer will be freed and a new buffer of the requested size + * will be allocated on the device. + * + * @param new_size Size in bytes to allocate for the workspace. + */ + void realloc(size_t new_size) { + if (new_size > allocated) { + clear(); + ACL_CHECK(aclrtMalloc(&ptr, new_size, ACL_MEM_MALLOC_HUGE_FIRST)); + allocated = new_size; } - ACL_CHECK(aclrtMalloc(&g_nz_workspace, new_size, ACL_MEM_MALLOC_HUGE_FIRST)); - g_nz_workspace_allocated = new_size; } - } -} + + /** + * @brief Get the device buffer pointer. + * + * @return Pointer to the allocated buffer, or nullptr if not allocated. + */ + void* get() const { return ptr; } +}; + +/** + * @brief Global array of NZ workspaces, one per device. + */ +static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES]; /** * @brief Convert tensor weights to NZ format using Ascend CANN API. @@ -1149,13 +1201,13 @@ namespace { * improve performance on certain hardware. * * @param tensor Pointer to the input ggml_tensor containing the weights. - * @param data Pointer to the raw data buffer for the tensor weights. * @param offset Byte offset within the tensor data buffer where weights start. + * @param device device id. * * @note The workspace buffer used in this function is managed globally and reused * across calls. This reduces overhead from repeated memory allocation and deallocation. */ -static void weight_format_to_nz(ggml_tensor *tensor, const void *data, size_t offset) { +static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device) { aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset); uint64_t workspaceSize = 0; @@ -1165,7 +1217,9 @@ static void weight_format_to_nz(ggml_tensor *tensor, const void *data, size_t of ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, &workspaceSize, &executor)); // Avoid frequent malloc/free of the workspace. - relloc_nz_workspace(workspaceSize); + g_nz_workspaces[device].realloc(workspaceSize); + + void* g_nz_workspace = g_nz_workspaces[device].get(); ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr)); ACL_CHECK(aclDestroyTensor(weightTransposed)); @@ -1196,14 +1250,14 @@ static void ggml_backend_cann_buffer_set_tensor( // Why aclrtSynchronizeDevice? // Only check env once. - static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("")); + static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); if (!need_transform(tensor->type)) { ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) { GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[3] == 1); - weight_format_to_nz(tensor, data, offset); + weight_format_to_nz(tensor, offset, ctx->device); } } else { void *transform_buffer = malloc(size); @@ -1279,6 +1333,10 @@ static bool ggml_backend_cann_buffer_cpy_tensor( ACL_MEMCPY_DEVICE_TO_DEVICE)); return true; } else { +#ifdef ASCEND_310P + // TODO: Support 310p P2P copy + return false; +#endif // Different device but can access by peer. int32_t canAccessPeer = 0; ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device, @@ -1439,7 +1497,7 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size( int64_t ne0 = tensor->ne[0]; // Only check env once. - static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("")); + static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); // last line must bigger than 32, because every single op deal at // least 32 bytes. @@ -2000,6 +2058,8 @@ static bool ggml_backend_cann_cpy_tensor_async( GGML_ASSERT(ggml_backend_is_cann(backend_src) || ggml_backend_is_cann(backend_dst)); + GGML_ASSERT(!is_matmul_weight((const ggml_tensor*)src)); + if (!ggml_backend_buffer_is_cann(src->buffer) || !ggml_backend_buffer_is_cann(dst->buffer)) { return false; @@ -2020,6 +2080,10 @@ static bool ggml_backend_cann_cpy_tensor_async( return true; } if (backend_src != backend_dst) { +#ifdef ASCEND_310P + // TODO: Support 310p P2P copy + return false; +#endif ggml_backend_cann_buffer_context* buf_ctx_src = (ggml_backend_cann_buffer_context*)buf_src->context; ggml_backend_cann_buffer_context* buf_ctx_dst = @@ -2036,7 +2100,6 @@ static bool ggml_backend_cann_cpy_tensor_async( } // need open both directions for memcpyasync between devices. - ggml_cann_set_device(cann_ctx_dst->device); ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0)); ggml_cann_set_device(cann_ctx_src->device); ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0)); @@ -2046,9 +2109,17 @@ static bool ggml_backend_cann_cpy_tensor_async( ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, cann_ctx_src->stream())); - - //TODO: workaround for Event didn`t work here. - aclrtSynchronizeStream(cann_ctx_src->stream()); + // record event on src stream after the copy + // TODO: this event is not effective with acl graph mode, change to use aclrtSynchronizeStream + // if (!cann_ctx_src->copy_event) { + // ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC)); + // } + // ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream())); + + // // wait on dst stream for the copy to complete + // ggml_cann_set_device(cann_ctx_dst->device); + // ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event)); + ACL_CHECK(aclrtSynchronizeStream(cann_ctx_src->stream())); } else { // src and dst are on the same backend ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, @@ -2077,30 +2148,52 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) { #ifdef USE_ACL_GRAPH /** - * @brief Populate the internal CANN graph node properties from the ggml computation graph. + * @brief Add a new CANN graph to the LRU cache by populating node properties from the ggml graph. + * + * This function creates a new ggml_cann_graph object and fills its node properties + * (operation type, dimensions, strides, input sources, and operation parameters) + * based on the current ggml computation graph. + * + * Each node in the ggml graph is mapped to a property entry in the new CANN graph: + * - node address + * - operation type + * - shape (ne) and strides (nb) + * - source tensor addresses + * - operation parameters * - * This function copies all node attributes (operation type, dimensions, strides, input sources, - * and operation parameters) into the cached CANN graph structure for later reuse or comparison. + * After initialization, the new graph is pushed into the LRU cache owned by the + * CANN backend context. The cache takes ownership of the graph and manages its + * lifetime (including deletion upon eviction). * - * @param cann_ctx The CANN backend context. - * @param cgraph The ggml computational graph. + * @param cann_ctx The CANN backend context containing the graph cache. + * @param cgraph The current ggml computation graph. */ -static void set_ggml_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { - for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) { +static void add_lru_matched_graph_node_properties( + ggml_backend_cann_context * cann_ctx, + ggml_cgraph * cgraph) { + // Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache). + ggml_cann_graph * new_graph = new ggml_cann_graph(); + new_graph->ggml_graph_properties.resize(cgraph->n_nodes); + + for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) { ggml_tensor * node = cgraph->nodes[node_idx]; - cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_address = node->data; - cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_op = node->op; + auto & prop = new_graph->ggml_graph_properties[node_idx]; - for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { - cann_ctx->cann_graph->ggml_graph_properties[node_idx].ne[dim] = node->ne[dim]; - cann_ctx->cann_graph->ggml_graph_properties[node_idx].nb[dim] = node->nb[dim]; - } - for (int src = 0; src < GGML_MAX_SRC; src++) { - cann_ctx->cann_graph->ggml_graph_properties[node_idx].src_address[src] = - node->src[src] ? node->src[src]->data : nullptr; + prop.node_address = node->data; + prop.node_op = node->op; + + std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne); + std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb); + + for (int src = 0; src < GGML_MAX_SRC; ++src) { + prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr; } - memcpy(cann_ctx->cann_graph->ggml_graph_properties[node_idx].op_params, node->op_params, GGML_MAX_OP_PARAMS); + + memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS); } + + // Insert into the LRU cache (cache takes ownership and will delete it when evicted). + cann_ctx->graph_lru_cache.push(new_graph); } /** @@ -2145,30 +2238,45 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } /** - * @brief Determine if the CANN graph needs to be rebuilt due to graph changes. + * @brief Check whether there is a cached CANN graph that matches the current ggml graph. + * + * This function iterates through the cached CANN graphs stored in the LRU cache and + * compares them against the given ggml computation graph. A match requires that the + * number of nodes is the same and that each node’s properties (operation type, + * dimensions, strides, inputs, and operation parameters) are identical. * - * This checks whether the number or properties of ggml graph nodes have changed - * compared to the last captured CANN graph. If so, the CANN graph must be re-captured. + * If a matching graph is found, it is promoted to the front of the LRU cache and the + * function returns true. Otherwise, the function returns false, indicating that a new + * CANN graph needs to be captured. * - * @param cann_ctx The CANN backend context. + * @param cann_ctx The CANN backend context containing the graph cache. * @param cgraph The current ggml computation graph. - * @return true if an update is required; false otherwise. - */ -static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { - // The number of nodes is different, so the graph needs to be reconstructed. - if (cann_ctx->cann_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { - cann_ctx->cann_graph->ggml_graph_properties.resize(cgraph->n_nodes); - return true; - } + * @return true if a matching cached graph exists; false otherwise. + */ +static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { + ggml_cann_graph_lru_cache &lru_cache = cann_ctx->graph_lru_cache; + for (auto &graph_ptr : lru_cache.cache_list) { + // Skip graphs with a different number of nodes. + if (graph_ptr->ggml_graph_properties.size() != static_cast(cgraph->n_nodes)) { + continue; + } + + // Check if all nodes match. + bool all_match = true; + for (int i = 0; i < cgraph->n_nodes; ++i) { + if (!ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i])) { + all_match = false; + break; + } + } - // The number of nodes is the same; iterate over each node to check whether they match. - for (int i = 0; i < cgraph->n_nodes; i++) { - bool has_matching_properties = ggml_graph_node_has_matching_properties( - cgraph->nodes[i], &cann_ctx->cann_graph->ggml_graph_properties[i]); - if(!has_matching_properties) { + if (all_match) { + // update cache_list && renturn graph_ptr + lru_cache.move_to_front(graph_ptr); return true; } } + return false; } #endif // USE_ACL_GRAPH @@ -2187,17 +2295,13 @@ static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, * @param cann_graph_update_required Whether graph capture is needed due to graph changes. */ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph, - bool & use_cann_graph, bool & cann_graph_update_required) { + bool & use_cann_graph, bool & cann_graph_update_required) { #ifdef USE_ACL_GRAPH + ggml_cann_graph* matched_graph = cann_ctx->graph_lru_cache.cache_list.front(); if (use_cann_graph && cann_graph_update_required) { - if (cann_ctx->cann_graph->graph != nullptr) { - ACL_CHECK(aclmdlRIDestroy(cann_ctx->cann_graph->graph)); - cann_ctx->cann_graph->graph = nullptr; - } ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL)); } #endif // USE_ACL_GRAPH - // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph. // With the use of CANN graphs, the execution will be performed by the graph launch. if (!use_cann_graph || cann_graph_update_required) { @@ -2218,12 +2322,12 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx #ifdef USE_ACL_GRAPH if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture - ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &cann_ctx->cann_graph->graph)); + ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph)); } if (use_cann_graph) { // Execute graph - ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->cann_graph->graph, cann_ctx->stream())); + ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream())); } #endif // USE_ACL_GRAPH } @@ -2246,30 +2350,46 @@ static enum ggml_status ggml_backend_cann_graph_compute( ggml_backend_cann_context* cann_ctx = (ggml_backend_cann_context*)backend->context; ggml_cann_set_device(cann_ctx->device); - release_nz_workspace(); + g_nz_workspaces[cann_ctx->device].clear(); + + // calculate rope cache for fist layer in current device. + cann_ctx->rope_cache.cached = false; + #ifdef USE_ACL_GRAPH bool use_cann_graph = true; bool cann_graph_update_required = false; - // check environment LLAMA_SET_ROWS - if (!cann_ctx->support_set_rows) { + static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); + if (!prefill_use_graph) { + // Do not use acl_graph for prefill. + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + // TODO: Optimize here. Currently, we can only + // get seq_len by FA's input. + if (node->op == GGML_OP_FLASH_ATTN_EXT) { + // Q -> src[0], shape: [B, S, N, D] + use_cann_graph = (node->src[0]->ne[1] == 1); + break; + } + } + } + + if (!cann_ctx->acl_graph_mode) { use_cann_graph = false; } if (use_cann_graph) { - if (cann_ctx->cann_graph == nullptr) { - cann_ctx->cann_graph.reset(new ggml_cann_graph()); - cann_graph_update_required = true; + // If no matching graph is found, the graph needs to be recaptured. + cann_graph_update_required = !is_matched_graph(cann_ctx, cgraph); + if (cann_graph_update_required) { + // If no matching graph is found, add a new ACL graph. + add_lru_matched_graph_node_properties(cann_ctx, cgraph); } - - cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph); - set_ggml_graph_node_properties(cann_ctx, cgraph); } #else bool use_cann_graph = false; bool cann_graph_update_required = false; #endif // USE_ACL_GRAPH - evaluate_and_capture_cann_graph( cann_ctx, cgraph, @@ -2336,7 +2456,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_0: #ifdef ASCEND_310P - // Q4 && Q8 per group is not suppor on 310p device + // Q4 && Q8 per group is not support on 310p device return false; #endif // only support contiguous for quantized types. @@ -2354,7 +2474,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_0: #ifdef ASCEND_310P - // Q4 && Q8 per group is not suppor on 310p device + // Q4 && Q8 per group is not support on 310p device return false; #endif // only support contiguous for quantized types. @@ -2405,16 +2525,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, } case GGML_OP_ROPE: { // TODO: with ops-test v == 1 - float ext_factor = 0.0f; - memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float)); // TODO: n_dims <= ne0 if (op->src[0]->ne[0] != op->op_params[1]) { return false; } - // TODO: ext_factor != 0 - if (ext_factor != 0) { - return false; - } const int mode = ((const int32_t *) op->op_params)[2]; if (mode & GGML_ROPE_TYPE_MROPE) { @@ -2423,10 +2537,11 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, if (mode & GGML_ROPE_TYPE_VISION) { return false; } - +#ifdef ASCEND_310P if(!ggml_is_contiguous(op->src[0])){ return false; } +#endif return true; } case GGML_OP_UPSCALE: { @@ -2488,15 +2603,17 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_ARGMAX: case GGML_OP_COS: case GGML_OP_SIN: - case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_LOG: case GGML_OP_MEAN: case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: return true; + case GGML_OP_CONV_TRANSPOSE_1D: + // TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255. + return (op->src[0]->ne[0] - 1) <= 255; case GGML_OP_SCALE: float bias; - memcpy(&bias, (float*)op->op_params + 1, sizeof(float)); + memcpy(&bias, (const float *)(op->op_params) + 1, sizeof(float)); return bias == 0.0f; // TODO: support bias != 0.0f case GGML_OP_SOFT_MAX: // TODO: support attention sinks [TAG_ATTN_SINKS] @@ -2505,6 +2622,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, } return true; case GGML_OP_FLASH_ATTN_EXT:{ +#ifdef ASCEND_310P + // FA not support on 310p device + return false; +#endif // derived from [ggml-cuda.cu] if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){ return false; @@ -2523,15 +2644,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, // different head sizes of K and V are not supported yet return false; } - if (op->src[0]->ne[0] == 192) { - return false; - } - if (op->src[0]->ne[0] == 576) { - // DeepSeek MLA + if (op->src[0]->ne[0] % 16 != 0) { + // TODO: padding to support return false; } float logitSoftcap = 0.0f; - memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float)); + memcpy(&logitSoftcap, (const float *)(op->op_params) + 2, sizeof(float)); if(logitSoftcap != 0.0f) { return false; } @@ -2638,6 +2756,7 @@ static const ggml_backend_i ggml_backend_cann_interface = { /* .graph_compute = */ ggml_backend_cann_graph_compute, /* .event_record = */ ggml_backend_cann_event_record, /* .event_wait = */ ggml_backend_cann_event_wait, + /* .graph_optimize = */ NULL, }; /** diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index ce0a3e1285e..36990575075 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -224,7 +224,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME) string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos) if (NOT ${feature_pos} EQUAL -1) - message(STATUS "ARM feature ${feature} enabled") + # Special handling for MATMUL_INT8 when machine doesn't support i8mm + if ("${feature}" STREQUAL "MATMUL_INT8" AND GGML_MACHINE_SUPPORTS_noi8mm) + message(STATUS "ARM feature ${feature} detected but unsetting due to machine not supporting i8mm") + list(APPEND ARCH_FLAGS -U__ARM_FEATURE_MATMUL_INT8) + else() + message(STATUS "ARM feature ${feature} enabled") + endif() endif() endforeach() endif() @@ -433,15 +439,22 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/arch/riscv/quants.c ggml-cpu/arch/riscv/repack.cpp ) - if (GGML_RVV) - if (GGML_XTHEADVECTOR) - list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d) - elseif (GGML_RV_ZFH) - list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d) - else() - list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d) + set(MARCH_STR "rv64gc") + if (GGML_RV_ZFH) + string(APPEND MARCH_STR "_zfh") + endif() + if (GGML_XTHEADVECTOR) + string(APPEND MARCH_STR "_xtheadvector") + elseif (GGML_RVV) + string(APPEND MARCH_STR "_v") + if (GGML_RV_ZVFH) + string(APPEND MARCH_STR "_zvfh") endif() endif() + if (GGML_RV_ZICBOP) + string(APPEND MARCH_STR "_zicbop") + endif() + list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) elseif (GGML_SYSTEM_ARCH STREQUAL "s390x") message(STATUS "s390x detected") list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c) @@ -450,7 +463,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # TODO: Separation to determine activation of VX/VXE/VXE2 if (${S390X_M} MATCHES "8561|8562") - set(GGML_NNPA OFF) message(STATUS "z15 target") list(APPEND ARCH_FLAGS -march=z15) elseif (${S390X_M} MATCHES "3931") @@ -472,11 +484,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND ARCH_FLAGS -mvx -mzvector) list(APPEND ARCH_DEFINITIONS GGML_VXE) endif() - - if (GGML_NNPA) - message(STATUS "NNPA enabled") - list(APPEND ARCH_DEFINITIONS GGML_NNPA) - endif() elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm") message(STATUS "Wasm detected") list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c) @@ -497,9 +504,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_TAG "v1.11.0") + set(KLEIDIAI_COMMIT_TAG "v1.13.0") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2") + set(KLEIDIAI_ARCHIVE_MD5 "d82a8de939d9814621a5ba23907bdac1") if (POLICY CMP0135) cmake_policy(SET CMP0135 NEW) @@ -555,6 +562,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c) @@ -576,7 +584,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c - ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c) + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c + ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S) set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2") endif() diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 258857b0075..867e158dcaa 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -7,7 +7,7 @@ #include "ggml-cpu.h" #include "traits.h" -#if defined(__gnu_linux__) +#if defined(__linux__) #include #include #endif @@ -186,7 +186,7 @@ static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_ty #define XFEATURE_XTILEDATA 18 static bool ggml_amx_init() { -#if defined(__gnu_linux__) +#if defined(__linux__) if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) { fprintf(stderr, "AMX is not ready to be used!\n"); return false; @@ -194,6 +194,8 @@ static bool ggml_amx_init() { return true; #elif defined(_WIN32) return true; +#else + return false; #endif } diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index f476127995b..373408a9c09 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -73,7 +73,6 @@ #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K -#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -151,8 +150,6 @@ #elif defined(__s390x__) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K -#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0 -#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K diff --git a/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ggml/src/ggml-cpu/arch/powerpc/quants.c index 49aae7a23bb..d3dfd049eaf 100644 --- a/ggml/src/ggml-cpu/arch/powerpc/quants.c +++ b/ggml/src/ggml-cpu/arch/powerpc/quants.c @@ -278,6 +278,72 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } +void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_MXFP4 == 0); + static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same"); + + const block_mxfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK_MXFP4; + + int ib = 0; + float sumf = 0; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char vshift4 = vec_splats((unsigned char)4); + vector float vsumf0 = vec_splats(0.0f); + + vector signed char kv = vec_xl(0, (const signed char *)kvalues_mxfp4); + +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d) * + GGML_E8M0_TO_FP32_HALF(x[ib].e)); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector signed char qxs = (vector signed char)vec_xl(0, x[ib].qs); + + vector unsigned char lo_nibbles = (vector unsigned char)vec_and(qxs, lowMask); + vector unsigned char hi_nibbles = (vector unsigned char)vec_sr(qxs, vshift4); + + vector signed char q4x0 = vec_perm(kv, kv, lo_nibbles); + vector signed char q4x1 = vec_perm(kv, kv, hi_nibbles); + + vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi0 = vec_sum4s(qv1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vyd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + sumf = vec_extract(vsumf0, 0); + *s = sumf; +#else + UNUSED(x); + UNUSED(y); + UNUSED(ib); + UNUSED(sumf); + ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index 6c74417c90c..ee41a3502e8 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -1270,29 +1270,40 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - int tmp, tmp2, sumi; + float ftmp, ft2; + const uint8_t * restrict q40; + const uint8_t * restrict q41; + const uint8_t * restrict q42; + const uint8_t * restrict q43; + const int8_t * restrict q80; + const int8_t * restrict q81; + const int8_t * restrict q82; + const int8_t * restrict q83; + int s0, s1, s2, s3; + __asm__ __volatile__( - "vsetivli zero, 12, e8, m1\n\t" - "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]} - "vsetivli zero, 4, e32, m1\n\t" + "li %[s1], 8\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vle32.v v1, (%[s6b])\n\t" + "vslide1down.vx v1, v1, zero\n\t" + "vmv.v.x v16, zero\n\t" "vslidedown.vi v2, v1, 2\n\t" "vmv1r.v v3, v2\n\t" "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} - "vsetivli zero, 2, e32, m1\n\t" + "vsetivli zero, 2, e32, m1, ta, ma\n\t" "vmv.v.i v4, 4\n\t" "vand.vx v8, v1, %[kmask1]\n\t" "vslide1up.vx v5, v4, zero\n\t" // {0, 4} "vsrl.vi v6, v1, 6\n\t" "vsrl.vv v7, v2, v5\n\t" + "vsse32.v v8, (%[utmp]), %[s1]\n\t" "vand.vx v0, v6, %[kmask3]\n\t" "vand.vx v2, v7, %[kmask2]\n\t" "vsll.vi v6, v0, 4\n\t" - "li %[t2], 8\n\t" - "addi %[t1], %[utmp], 4\n\t" + "addi %[s0], %[utmp], 4\n\t" "vor.vv v1, v6, v2\n\t" - "vsse32.v v8, (%[utmp]), %[t2]\n\t" - "vsse32.v v1, (%[t1]), %[t2]\n\t" - "vsetivli zero, 8, e16, m1\n\t" + "vsse32.v v1, (%[s0]), %[s1]\n\t" + "vsetivli zero, 8, e16, m1, ta, ma\n\t" "vle32.v v2, (%[bsums])\n\t" "vnsrl.wi v0, v2, 0\n\t" "vnsrl.wi v1, v2, 16\n\t" @@ -1300,13 +1311,131 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi "vle8.v v3, (%[mins])\n\t" "vzext.vf2 v4, v3\n\t" "vwmul.vv v6, v4, v2\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vredsum.vs v0, v6, v16\n\t" + "vredsum.vs v0, v7, v0\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfmv.f.s %[ftmp], v0\n\t" + "vsetivli zero, 16, e8, m1, ta, ma\n\t" + "vle8.v v0, (%[xs])\n\t" + "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t" + "addi %[q40], %[xs], 64\n\t" + "addi %[q41], %[xs], 16\n\t" + "addi %[q42], %[xs], 32\n\t" + "addi %[q43], %[xs], 48\n\t" + "addi %[q80], %[ys], 64\n\t" + "vle8.v v1, (%[q41])\n\t" + "vle8.v v2, (%[q42])\n\t" + "addi %[q81], %[ys], 16\n\t" + "addi %[q41], %[q41], 64\n\t" + "addi %[q82], %[ys], 32\n\t" + "vle8.v v3, (%[q43])\n\t" + "vle8.v v8, (%[ys])\n\t" + "addi %[q42], %[q42], 64\n\t" + "addi %[q83], %[ys], 48\n\t" + "addi %[q43], %[q43], 64\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vle8.v v9, (%[q81])\n\t" + "vle8.v v10, (%[q82])\n\t" + "vand.vi v0, v0, 0xF\n\t" + "addi %[q81], %[q81], 64\n\t" + "vsrl.vi v5, v1, 4\n\t" + "addi %[q82], %[q82], 64\n\t" + "vle8.v v11, (%[q83])\n\t" + "vle8.v v12, (%[q80])\n\t" + "vand.vi v1, v1, 0xF\n\t" + "addi %[q83], %[q83], 64\n\t" + "vsrl.vi v6, v2, 4\n\t" + "addi %[q80], %[q80], 64\n\t" + "vle8.v v13, (%[q81])\n\t" + "vle8.v v14, (%[q82])\n\t" + "vand.vi v2, v2, 0xF\n\t" + "addi %[q81], %[q81], 64\n\t" + "vsrl.vi v7, v3, 4\n\t" + "addi %[q82], %[q82], 64\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vle8.v v15, (%[q83])\n\t" + "vle8.v v0, (%[q40])\n\t" + "vand.vi v3, v3, 0xF\n\t" + "addi %[q83], %[q83], 64\n\t" + "vwmul.vv v24, v2, v12\n\t" + "vwmul.vv v20, v4, v10\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmacc.vv v16, v1, v9\n\t" + "vle8.v v1, (%[q41])\n\t" + "vle8.v v2, (%[q42])\n\t" + "vwmacc.vv v24, v3, v13\n\t" + "vwmacc.vv v20, v5, v11\n\t" + "vwmacc.vv v28, v7, v15\n\t" + "addi %[q40], %[q80], 64\n\t" + "addi %[q41], %[q81], 64\n\t" + "vle8.v v3, (%[q43])\n\t" + "vle8.v v8, (%[q80])\n\t" + "addi %[q42], %[q82], 64\n\t" + "addi %[q43], %[q83], 64\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vle8.v v9, (%[q81])\n\t" + "vle8.v v10, (%[q82])\n\t" + "vand.vi v0, v0, 0xF\n\t" + "vsrl.vi v5, v1, 4\n\t" + "vsrl.vi v7, v3, 4\n\t" + "vand.vi v3, v3, 0xF\n\t" + "vle8.v v11, (%[q83])\n\t" + "vle8.v v12, (%[q40])\n\t" + "vand.vi v1, v1, 0xF\n\t" + "vsrl.vi v6, v2, 4\n\t" + "vand.vi v2, v2, 0xF\n\t" + "vwmul.vv v18, v0, v8\n\t" + "vle8.v v13, (%[q41])\n\t" + "vle8.v v14, (%[q42])\n\t" + "vwmul.vv v26, v2, v12\n\t" + "vwmul.vv v22, v4, v10\n\t" + "vwmul.vv v30, v6, v14\n\t" + "vwmacc.vv v18, v1, v9\n\t" + "vle8.v v15, (%[q43])\n\t" + "vwmacc.vv v26, v3, v13\n\t" + "vwmacc.vv v22, v5, v11\n\t" + "vwmacc.vv v30, v7, v15\n\t" "vmv.v.x v0, zero\n\t" - "vsetivli zero, 8, e32, m2\n\t" - "vredsum.vs v0, v6, v0\n\t" - "vmv.x.s %[sumi], v0" - : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi) - : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) - , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1) + "vsetivli zero, 16, e16, m2, ta, ma\n\t" + "vwredsum.vs v4, v16, v0\n\t" + "lbu %[s0], 0(%[scale])\n\t" + "vwredsum.vs v5, v20, v0\n\t" + "lbu %[s1], 1(%[scale])\n\t" + "vwredsum.vs v6, v24, v0\n\t" + "lbu %[s2], 2(%[scale])\n\t" + "vwredsum.vs v7, v28, v0\n\t" + "lbu %[s3], 3(%[scale])\n\t" + "vwredsum.vs v8, v18, v0\n\t" + "lbu %[q40], 4(%[scale])\n\t" + "vwredsum.vs v9, v22, v0\n\t" + "lbu %[q41], 5(%[scale])\n\t" + "vwredsum.vs v10, v26, v0\n\t" + "lbu %[q42], 6(%[scale])\n\t" + "vwredsum.vs v11, v30, v0\n\t" + "lbu %[q43], 7(%[scale])\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vmul.vx v0, v4, %[s0]\n\t" + "vmul.vx v1, v8, %[q40]\n\t" + "vmacc.vx v0, %[s1], v5\n\t" + "vmacc.vx v1, %[q41], v9\n\t" + "vmacc.vx v0, %[s2], v6\n\t" + "vmacc.vx v1, %[q42], v10\n\t" + "vmacc.vx v0, %[s3], v7\n\t" + "vmacc.vx v1, %[q43], v11\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfcvt.f.x.v v1, v1\n\t" + "vfmv.f.s %[ft2], v0\n\t" + "vfmv.f.s %[ftmp], v1\n\t" + "fadd.s %[ft2], %[ft2], %[ftmp]\n\t" + "fmadd.s %[sumf], %[d], %[ft2], %[sumf]" + : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2) + , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3) + , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43) + , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83) + : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales) + , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) + , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin) , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) : "memory" , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" @@ -1314,59 +1443,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" ); - sumf -= dmin * sumi; - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - sumi = 0; - const uint8_t * scale = scales; - - for (int j = 0; j < QK_K/128; ++j) { - int vl128 = 128, vl64 = 64, vl32 = 32; - __asm__ __volatile__( - "vsetvli zero, %[vl128], e8, m8\n\t" - "vle8.v v8, (%[q8])\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vle8.v v0, (%[q4])\n\t" - "vsrl.vi v4, v0, 4\n\t" - "vand.vi v0, v0, 0xF\n\t" - "vsetvli zero, %[vl32], e8, m2\n\t" - "vwmul.vv v28, v6, v14\n\t" - "vwmul.vv v20, v4, v10\n\t" - "vwmul.vv v24, v2, v12\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vle8.v v2, (%[scale])\n\t" - "vmv.v.x v0, zero\n\t" - "vzext.vf4 v1, v2\n\t" - "vsetvli zero, %[vl32], e16, m4\n\t" - "vwredsum.vs v6, v24, v0\n\t" - "vwredsum.vs v7, v28, v0\n\t" - "vwredsum.vs v4, v16, v0\n\t" - "vwredsum.vs v5, v20, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vslideup.vi v6, v7, 1\n\t" - "vslideup.vi v4, v5, 1\n\t" - "vslideup.vi v4, v6, 2\n\t" - "vmul.vv v8, v4, v1\n\t" - "vredsum.vs v0, v8, v0\n\t" - "vmv.x.s %[tmp], v0\n\t" - "add %[sumi], %[sumi], %[tmp]" - : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi) - : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32) - , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - - q4 += 64; q8 += 128; scale += 4; - } - - sumf += d * sumi; } break; default: @@ -1693,6 +1769,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi case 128: for (int i = 0; i < nb; ++i) { + __builtin_prefetch(&x[i + 1].d, 0, 1); + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; const uint8_t * restrict q6 = x[i].ql; @@ -1701,23 +1779,59 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int8_t * restrict scale = x[i].scales; - int sum_t = 0; - int t0; + int q6h; + float ftmp; for (int j = 0; j < QK_K/128; ++j) { __asm__ __volatile__( + "addi %[q6h], %[q6], 32\n\t" + "ld t0, 0(%[scale])\n\t" + "addi %[scale], %[scale], 8\n\t" + "slli t6, t0, 1 * 8\n\t" + "lb zero, 0(%[q6])\n\t" + "slli t5, t0, 2 * 8\n\t" + "slli t4, t0, 3 * 8\n\t" + "lb zero, 0(%[q6h])\n\t" + "slli t3, t0, 4 * 8\n\t" + "slli t2, t0, 5 * 8\n\t" + "lb zero, 0(%[qh])\n\t" + "lb zero, 31(%[q6h])\n\t" + "slli t1, t0, 6 * 8\n\t" + "srai a7, t0, 56\n\t" "vsetvli zero, %[vl32], e8, m2\n\t" + "vle8.v v8, (%[q6])\n\t" + "srai t6, t6, 56\n\t" + "srai t5, t5, 56\n\t" + "srai t4, t4, 56\n\t" + "srai t3, t3, 56\n\t" + "vle8.v v10, (%[q6h])\n\t" + "addi %[q6], %[q6], 64\n\t" + "slli t0, t0, 7 * 8\n\t" + "srai t2, t2, 56\n\t" + "srai t1, t1, 56\n\t" + "srai t0, t0, 56\n\t" "vle8.v v4, (%[qh])\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v10, 4\n\t" + "lb zero, 0(%[q8])\n\t" + "vand.vi v8, v8, 0xF\n\t" + "vand.vi v10, v10, 0xF\n\t" + "lb zero, 32(%[q8])\n\t" "vsll.vi v0, v4, 4\n\t" "vsll.vi v2, v4, 2\n\t" + "lb zero, 64(%[q8])\n\t" "vsrl.vi v6, v4, 2\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vle8.v v8, (%[q6])\n\t" - "vsrl.vi v12, v8, 4\n\t" - "vand.vi v8, v8, 0xF\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" "vand.vx v0, v0, %[mask]\n\t" + "lb zero, 96(%[q8])\n\t" + "vand.vx v2, v2, %[mask]\n\t" + "vand.vx v4, v4, %[mask]\n\t" + "vand.vx v6, v6, %[mask]\n\t" "vor.vv v8, v8, v0\n\t" + "lb zero, 127(%[q8])\n\t" + "vor.vv v10, v10, v2\n\t" + "vor.vv v12, v12, v4\n\t" + "vor.vv v14, v14, v6\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" "vle8.v v0, (%[q8])\n\t" "vsub.vx v8, v8, %[vl32]\n\t" "vsetvli zero, %[vl64], e8, m4\n\t" @@ -1734,34 +1848,34 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi "vwredsum.vs v13, v28, v0\n\t" "vwredsum.vs v14, v30, v0\n\t" "vsetivli zero, 4, e32, m1\n\t" - "vslideup.vi v10, v9, 1\n\t" - "vslideup.vi v8, v7, 1\n\t" - "vslideup.vi v11, v12, 1\n\t" - "vslideup.vi v13, v14, 1\n\t" - "vslideup.vi v10, v8, 2\n\t" - "vslideup.vi v11, v13, 2\n\t" - "vsetivli zero, 8, e32, m2\n\t" - "vle8.v v2, (%[scale])\n\t" - "vsext.vf4 v4, v2\n\t" - "vmul.vv v2, v4, v10\n\t" - "vredsum.vs v0, v2, v0\n\t" - "vmv.x.s %[t0], v0\n\t" - "add %[sumi], %[sumi], %[t0]" - : [sumi] "+&r" (sum_t), [t0] "=&r" (t0) - : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale) + "vmul.vx v0, v10, t0\n\t" + "vmul.vx v1, v9, t1\n\t" + "vmacc.vx v0, t2, v8\n\t" + "vmacc.vx v1, t3, v7\n\t" + "vmacc.vx v0, t4, v11\n\t" + "vmacc.vx v1, t5, v12\n\t" + "vmacc.vx v0, t6, v13\n\t" + "vmacc.vx v1, a7, v14\n\t" + "vadd.vv v0, v0, v1\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfmv.f.s %[ftmp], v0\n\t" + "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]" + : [q6] "+&r" (q6), [q6h] "=&r" (q6h) + , [scale] "+&r" (scale) + , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp) + : [qh] "r" (qh), [q8] "r" (q8) , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) - , [mask] "r" (0x30) + , [mask] "r" (0x30), [d] "f" (d) : "memory" , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7" + , "a6", "a5", "a4", "a3" ); - q6 += 64; qh += 32; q8 += 128; scale += 8; + qh += 32; q8 += 128; } - - sumf += d * sum_t; - } break; default: diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index 7e4229d0e46..dc1bba3a3e2 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -23,6 +23,27 @@ #define UNUSED GGML_UNUSED +#if defined(__VXE__) || defined(__VXE2__) +#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s +#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) +#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) +#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) +#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) +#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) +#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) +#define B8(c,s ) B7(c,s, c), B7(c,s, s) + +// precomputed tables for expanding 8bits to 8 bytes: +static const __attribute__((aligned(16))) uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b ) << 4 +static const __attribute__((aligned(16))) uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 + +// permute mask for byteswapping +static const uint8x16_t v_kperm = (const uint8x16_t){ + 7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8 +}; +#endif + void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); @@ -32,9 +53,9 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i #if defined(__VXE__) || defined(__VXE2__) for (int i = 0; i < nb; i++) { - __vector float srcv [8]; - __vector float asrcv[8]; - __vector float amaxv[8]; + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); @@ -53,8 +74,8 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i y[i].d = GGML_CPU_FP32_TO_FP16(d); for (int j = 0; j < 8; j++) { - const __vector float v = vec_mul(srcv[j], vec_splats(id)); - const __vector int32_t vi = vec_signed(v); + const float32x4_t v = vec_mul(srcv[j], vec_splats(id)); + const int32x4_t vi = vec_signed(v); y[i].qs[4*j + 0] = vec_extract(vi, 0); y[i].qs[4*j + 1] = vec_extract(vi, 1); @@ -77,9 +98,9 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i #if defined(__VXE__) || defined(__VXE2__) for (int i = 0; i < nb; i++) { - __vector float srcv [8]; - __vector float asrcv[8]; - __vector float amaxv[8]; + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); @@ -97,11 +118,11 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i y[i].d = GGML_CPU_FP32_TO_FP16(d); - __vector int32_t acc = vec_splats(0); + int32x4_t acc = vec_splats(0); for (int j = 0; j < 8; j++) { - const __vector float v = vec_mul(srcv[j], vec_splats(id)); - const __vector int32_t vi = vec_signed(v); + const float32x4_t v = vec_mul(srcv[j], vec_splats(id)); + const int32x4_t vi = vec_signed(v); y[i].qs[4*j + 0] = vec_extract(vi, 0); y[i].qs[4*j + 1] = vec_extract(vi, 1); @@ -141,37 +162,36 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi float sumf = 0; #if defined(__VXE__) || defined(__VXE2__) - __vector float acc = vec_splats(0.0f); + float32x4_t acc = vec_splats(0.0f); - const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F); - const __vector int8_t v_s = vec_splats( (const int8_t)0x08); + const uint8x16_t v_m = vec_splats((const uint8_t)0x0F); + const int8x16_t v_s = vec_splats( (const int8_t)0x08); for (; ib < nb; ++ib) { - const __vector uint8_t v_x = vec_xl(0, x[ib].qs); - const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m); - const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4); + const uint8x16_t v_x = vec_xl(0, x[ib].qs); + const int8x16_t v_xl = (const int8x16_t)(v_x & v_m); + const int8x16_t v_xh = (const int8x16_t)(v_x >> 4); - const __vector int8_t v_xls = vec_sub(v_xl, v_s); - const __vector int8_t v_xhs = vec_sub(v_xh, v_s); + const int8x16_t v_xls = vec_sub(v_xl, v_s); + const int8x16_t v_xhs = vec_sub(v_xh, v_s); - const __vector int8_t v_yl = vec_xl(0 , y[ib].qs); - const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs); + const int8x16_t v_yl = vec_xl(0 , y[ib].qs); + const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs); - const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl); - const __vector int16_t v_xylse = vec_mule(v_xls, v_yl); - const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh); - const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh); + const int16x8_t v_xylso = vec_mulo(v_xls, v_yl); + const int16x8_t v_xylse = vec_mule(v_xls, v_yl); + const int16x8_t v_xyhso = vec_mulo(v_xhs, v_yh); + const int16x8_t v_xyhse = vec_mule(v_xhs, v_yh); - __vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_); + int16x8_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_); - const __vector float v_xy = vec_float(vec_unpackh(v_xy_)); - const __vector float v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)); + const float32x4_t v_xy = vec_float(vec_unpackh(v_xy_)); + const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)); acc = vec_madd(v_xy, v_d, acc); } - sumf = acc[0] + acc[1] + acc[2] + acc[3]; - + sumf = vec_hsum_f32x4(acc); *s = sumf; #else UNUSED(nb); @@ -228,8 +248,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi acc = vec_madd(v_xy, v_d, acc); } - sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs; - + sumf = vec_hsum_f32x4(acc) + summs; *s = sumf; #else UNUSED(nb); @@ -241,6 +260,301 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } +void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(qk == QK5_0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0.0f; + +#if defined(__VXE__) || defined(__VXE2__) + float32x4_t v_sum0 = vec_splats(0.0f); + float32x4_t v_sum1 = vec_splats(0.0f); + + uint32_t qh0, qh1; + uint64_t tmp0[4], tmp1[4]; + + const uint8x16_t v_m = vec_splats((uint8_t)0x0F); + + #pragma GCC unroll 4 + for (; ib + 1 < nb; ib += 2) { + const block_q5_0 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_q5_0 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; + + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_1[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_1[(qh1 >> 24) ]; + + int8x16_t v_qh0l = vec_xl(0, (const int8_t *)(tmp0 + 0)); + int8x16_t v_qh0h = vec_xl(0, (const int8_t *)(tmp0 + 2)); + int8x16_t v_qh1l = vec_xl(0, (const int8_t *)(tmp1 + 0)); + int8x16_t v_qh1h = vec_xl(0, (const int8_t *)(tmp1 + 2)); + + // required for fixing the byteorder + v_qh0l = vec_perm(v_qh0l, v_qh0l, v_kperm); + v_qh0h = vec_perm(v_qh0h, v_qh0h, v_kperm); + v_qh1l = vec_perm(v_qh1l, v_qh1l, v_kperm); + v_qh1h = vec_perm(v_qh1h, v_qh1h, v_kperm); + + const uint8x16_t v_x0 = vec_xl(0, (const uint8_t *)x0->qs); + const uint8x16_t v_x1 = vec_xl(0, (const uint8_t *)x1->qs); + + int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m); + int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4); + int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m); + int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4); + + const int8x16_t v_x0lf = vec_sub(v_x0l, v_qh0l); + const int8x16_t v_x0hf = vec_sub(v_x0h, v_qh0h); + const int8x16_t v_x1lf = vec_sub(v_x1l, v_qh1l); + const int8x16_t v_x1hf = vec_sub(v_x1h, v_qh1h); + + const int8x16_t v_y0l = vec_xl(0, (const int8_t *)y0->qs); + const int8x16_t v_y0h = vec_xl(QK8_0/2, (const int8_t *)y0->qs); + const int8x16_t v_y1l = vec_xl(0, (const int8_t *)y1->qs); + const int8x16_t v_y1h = vec_xl(QK8_0/2, (const int8_t *)y1->qs); + + const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0lf, v_y0l), v_x0hf, v_y0h); + const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1lf, v_y1l), v_x1hf, v_y1h); + + const float32x4_t v_xy0f = vec_float(v_xy0); + const float32x4_t v_xy1f = vec_float(v_xy1); + + const float32x4_t v_d0 = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)); + const float32x4_t v_d1 = vec_splats(GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d)); + + v_sum0 = vec_madd(v_xy0f, v_d0, v_sum0); + v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1); + } + + sumf += vec_hsum_f32x4(v_sum0) + vec_hsum_f32x4(v_sum1); + + #pragma GCC unroll 4 + for (; ib < nb; ++ib) { + const block_q5_0 * GGML_RESTRICT x0 = &x[ib]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib]; + + uint32_t qh; + memcpy(&qh, x0->qh, sizeof(qh)); + + uint64_t tmp[4]; + tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_1[(qh >> 24) ]; + + int8x16_t v_qhl = vec_xl(0, (const int8_t *)(tmp + 0)); + int8x16_t v_qhh = vec_xl(0, (const int8_t *)(tmp + 2)); + + // required for fixing the byteorder + v_qhl = vec_perm(v_qhl, v_qhl, v_kperm); + v_qhh = vec_perm(v_qhh, v_qhh, v_kperm); + + const uint8x16_t v_x = vec_xl(0, (const uint8_t *)x0->qs); + int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m); + int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4); + + const int8x16_t v_xlf = vec_sub(v_xl, v_qhl); + const int8x16_t v_xhf = vec_sub(v_xh, v_qhh); + + const int8x16_t v_yl = vec_xl(0, (const int8_t *)y0->qs); + const int8x16_t v_yh = vec_xl(QK8_0/2, (const int8_t *)y0->qs); + + const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xlf, v_yl), v_xhf, v_yh); + const float32x4_t v_xyf = vec_float(v_xy); + + const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)); + const float32x4_t v_acc = vec_madd(v_xyf, v_d, vec_splats(0.0f)); + + sumf += vec_hsum_f32x4(v_acc); + } + + *s = sumf; +#else + UNUSED(nb); + UNUSED(x); + UNUSED(y); + UNUSED(ib); + UNUSED(sumf); + ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(qk == QK5_1); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0.0f; + +#if defined(__VXE__) || defined(__VXE2__) + float32x4_t v_sum0 = vec_splats(0.0f); + float32x4_t v_sum1 = vec_splats(0.0f); + + float summs0 = 0.0f; + float summs1 = 0.0f; + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + const uint8x16_t v_m = vec_splats((uint8_t)0x0F); + + #pragma GCC unroll 4 + for (; ib + 1 < nb; ib += 2) { + const block_q5_1 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_q5_1 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_1 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1]; + + summs0 += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s); + summs1 += GGML_CPU_FP16_TO_FP32(x1->m) * GGML_CPU_FP16_TO_FP32(y1->s); + + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_0[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_0[(qh1 >> 24) ]; + + int8x16_t v_qh0l = vec_xl(0, (const int8_t *)(tmp0 + 0)); + int8x16_t v_qh0h = vec_xl(0, (const int8_t *)(tmp0 + 2)); + int8x16_t v_qh1l = vec_xl(0, (const int8_t *)(tmp1 + 0)); + int8x16_t v_qh1h = vec_xl(0, (const int8_t *)(tmp1 + 2)); + + // required for fixing the byteorder + v_qh0l = vec_perm(v_qh0l, v_qh0l, v_kperm); + v_qh0h = vec_perm(v_qh0h, v_qh0h, v_kperm); + v_qh1l = vec_perm(v_qh1l, v_qh1l, v_kperm); + v_qh1h = vec_perm(v_qh1h, v_qh1h, v_kperm); + + const uint8x16_t v_x0 = vec_xl(0, x0->qs); + const uint8x16_t v_x1 = vec_xl(0, x1->qs); + + const int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m); + const int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4); + const int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m); + const int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4); + + const int8x16_t v_x0lf = vec_or(v_x0l, v_qh0l); + const int8x16_t v_x0hf = vec_or(v_x0h, v_qh0h); + const int8x16_t v_x1lf = vec_or(v_x1l, v_qh1l); + const int8x16_t v_x1hf = vec_or(v_x1h, v_qh1h); + + const int8x16_t v_y0l = vec_xl(0 , y0->qs); + const int8x16_t v_y0h = vec_xl(QK8_1/2, y0->qs); + const int8x16_t v_y1l = vec_xl(0 , y1->qs); + const int8x16_t v_y1h = vec_xl(QK8_1/2, y1->qs); + + const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0lf, v_y0l), v_x0hf, v_y0h); + const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1lf, v_y1l), v_x1hf, v_y1h); + + const float32x4_t v_xy0f = vec_float(v_xy0); + const float32x4_t v_xy1f = vec_float(v_xy1); + + const float32x4_t v_d0 = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)); + const float32x4_t v_d1 = vec_splats(GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d)); + + v_sum0 = vec_madd(v_xy0f, v_d0, v_sum0); + v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1); + } + + sumf += vec_hsum_f32x4(v_sum0) + vec_hsum_f32x4(v_sum1) + summs0 + summs1; + + #pragma GCC unroll 4 + for (; ib < nb; ++ib) { + const block_q5_1 * GGML_RESTRICT x0 = &x[ib]; + const block_q8_1 * GGML_RESTRICT y0 = &y[ib]; + + float summs = GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s); + + uint32_t qh; + memcpy(&qh, x0->qh, sizeof(qh)); + + uint64_t tmp[4]; + tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_0[(qh >> 24) ]; + + int8x16_t v_qhl = vec_xl(0, (const int8_t *)(tmp + 0)); + int8x16_t v_qhh = vec_xl(0, (const int8_t *)(tmp + 2)); + + // required for fixing the byteorder + v_qhl = vec_perm(v_qhl, v_qhl, v_kperm); + v_qhh = vec_perm(v_qhh, v_qhh, v_kperm); + + const uint8x16_t v_x = vec_xl(0, x0->qs); + const int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m); + const int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4); + + const int8x16_t v_xlf = vec_or(v_xl, v_qhl); + const int8x16_t v_xhf = vec_or(v_xh, v_qhh); + + const int8x16_t v_yl = vec_xl(0 , y0->qs); + const int8x16_t v_yh = vec_xl(QK8_1/2, y0->qs); + + const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xlf, v_yl), v_xhf, v_yh); + const float32x4_t v_xyf = vec_float(v_xy); + + const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)); + const float32x4_t v_acc = vec_madd(v_xyf, v_d, v_acc); + + sumf += vec_hsum_f32x4(v_acc) + summs; + } + + *s = sumf; +#else + UNUSED(nb); + UNUSED(x); + UNUSED(y); + UNUSED(ib); + UNUSED(sumf); + ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -259,7 +573,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi float sumf = 0; #if defined(__VXE__) || defined(__VXE2__) - __vector float acc = vec_splats(0.0f); + float32x4_t acc = vec_splats(0.0f); #pragma GCC unroll 8 for (; ib < nb; ++ib) { @@ -278,7 +592,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi acc = vec_madd(v_xy, v_d, acc); } - sumf = acc[0] + acc[1] + acc[2] + acc[3]; + sumf = vec_hsum_f32x4(acc); *s = sumf; #else @@ -402,10 +716,10 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]); isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]); - isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0]; - isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1]; - isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2]; - isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3]; + isum += vec_hsum_i32x4(isum0) * scale[0]; + isum += vec_hsum_i32x4(isum1) * scale[1]; + isum += vec_hsum_i32x4(isum2) * scale[2]; + isum += vec_hsum_i32x4(isum3) * scale[3]; scale += 4; @@ -503,7 +817,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm); const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]); - sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0]; + sumi1 += vec_hsum_i32x4(p1) * scales[2*j+0]; v_y[0] = vec_xl(0 , y0); v_y[1] = vec_xl(16, y0); @@ -513,7 +827,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4); const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]); - sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1]; + sumi2 += vec_hsum_i32x4(p2) * scales[2*j+1]; } sumf += d * (sumi1 + sumi2); @@ -595,7 +909,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh); const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh); const int32x4_t v_mins = vec_add(v_minsho, v_minshe); - const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]; + const int32_t mins = vec_hsum_i32x4(v_mins); const uint8_t * scales = (const uint8_t *)utmp; const uint8_t * GGML_RESTRICT x0l = x[i].qs; @@ -632,8 +946,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]); int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]); - sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++; - sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++; + sumi += vec_hsum_i32x4(sumi0) * *scales++; + sumi += vec_hsum_i32x4(sumi1) * *scales++; } sumf += d * sumi - dmin * mins; @@ -704,7 +1018,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh); const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe; - const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]; + const int32_t mins = vec_hsum_i32x4(v_mins); int32_t isum = 0; for (int j = 0; j < QK_K/128; ++j) { @@ -744,10 +1058,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]); int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]); - isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] + - (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] + - (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] + - (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3]; + isum += vec_hsum_i32x4(summs0) * scale[0] + + vec_hsum_i32x4(summs1) * scale[1] + + vec_hsum_i32x4(summs2) * scale[2] + + vec_hsum_i32x4(summs3) * scale[3]; scale += 4; @@ -778,10 +1092,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]); summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]); - isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] + - (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] + - (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] + - (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3]; + isum += vec_hsum_i32x4(summs0) * scale[0] + + vec_hsum_i32x4(summs1) * scale[1] + + vec_hsum_i32x4(summs2) * scale[2] + + vec_hsum_i32x4(summs3) * scale[3]; scale += 4; } @@ -969,7 +1283,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs); const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); - sumf += GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]); + sumf += GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d) * vec_hsum_i32x4(v_xy); } *s = sumf; @@ -1038,8 +1352,8 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v h >>= 4; - sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1; - sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2; + sumi1 += vec_hsum_i32x4(vsumi0) * ls1; + sumi2 += vec_hsum_i32x4(vsumi1) * ls2; } sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); diff --git a/ggml/src/ggml-cpu/common.h b/ggml/src/ggml-cpu/common.h index 353563dc35c..6adca5437f8 100644 --- a/ggml/src/ggml-cpu/common.h +++ b/ggml/src/ggml-cpu/common.h @@ -28,6 +28,14 @@ static inline float bf16_to_f32(ggml_bf16_t x) { return GGML_BF16_TO_FP32(x); } +static inline float i32_to_f32(int32_t x) { + return x; +} + +static inline int32_t f32_to_i32(float x) { + return x; +} + static inline float f32_to_f32(float x) { return x; } @@ -54,6 +62,12 @@ struct type_conversion_table { static constexpr ggml_bf16_t (*from_f32)(float) = f32_to_bf16; }; +template <> +struct type_conversion_table { + static constexpr float (*to_f32)(int32_t) = i32_to_f32; + static constexpr int32_t (*from_f32)(float) = f32_to_i32; +}; + static std::pair get_thread_range(const struct ggml_compute_params * params, const struct ggml_tensor * src0) { const int64_t ith = params->ith; const int64_t nth = params->nth; diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index d839cf5c55e..799e2b11872 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -68,12 +68,6 @@ struct ggml_compute_params { #endif // __VXE2__ #endif // __s390x__ && __VEC__ -#if defined(__s390x__) && defined(GGML_NNPA) -#ifndef __NNPA__ -#define __NNPA__ -#endif // __NNPA__ -#endif // __s390x__ && GGML_NNPA - #if defined(__ARM_FEATURE_SVE) #include #endif @@ -486,6 +480,19 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) { return v_abo + v_abe; } +/** + * @see https://github.com/ggml-org/llama.cpp/pull/14037 + */ +inline static float vec_hsum_f32x4(float32x4_t v) { + float32x4_t v_temp = v + vec_reve(v); + return v_temp[0] + v_temp[1]; +} + +inline static int32_t vec_hsum_i32x4(int32x4_t v) { + int32x4_t v_temp = v + vec_reve(v); + return v_temp[0] + v_temp[1]; +} + inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) { const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b); return acc + (vec_unpackh(p) + vec_unpackl(p)); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index f6bea3df34a..c1312908495 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -373,6 +373,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_I32] = { + .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32, + }, }; const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) { @@ -1876,10 +1879,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_im2col_back_f32(params, tensor); } break; + case GGML_OP_IM2COL_3D: + { + ggml_compute_forward_im2col_3d(params, tensor); + } break; case GGML_OP_CONV_2D: { ggml_compute_forward_conv_2d(params, tensor); } break; + case GGML_OP_CONV_3D: + { + ggml_compute_forward_conv_3d(params, tensor); + } break; case GGML_OP_CONV_2D_DW: { ggml_compute_forward_conv_2d_dw(params, tensor); @@ -2251,7 +2262,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_IM2COL: case GGML_OP_IM2COL_BACK: + case GGML_OP_IM2COL_3D: case GGML_OP_CONV_2D: + case GGML_OP_CONV_3D: case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_2D: @@ -2686,7 +2699,10 @@ struct ggml_cplan ggml_graph_plan( if (ggml_is_quantized(node->type) || // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32 (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) || - (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) { + (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16) || + // conversion between F32 and I32 + (node->src[0]->type == GGML_TYPE_F32 && node->src[1] && node->src[1]->type == GGML_TYPE_I32) || + (node->src[0]->type == GGML_TYPE_I32 && node->src[1] && node->src[1]->type == GGML_TYPE_F32)) { cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; } } break; @@ -2773,6 +2789,7 @@ struct ggml_cplan ggml_graph_plan( } } break; case GGML_OP_CONV_2D: + case GGML_OP_CONV_3D: { cur = GGML_IM2COL_WORK_SIZE; } break; @@ -3200,20 +3217,12 @@ void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) { __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); _mm_storel_epi64((__m128i *)(y + i), y_vec); } -#elif defined(__NNPA__) - for (; i + 7 < n; i += 8) { - float32x4_t v_xh = vec_xl(0, (const float *)(x + i + 0)); - float32x4_t v_xl = vec_xl(0, (const float *)(x + i + 4)); - uint16x8_t v_yd = vec_round_from_fp32(v_xh, v_xl, 0); - uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0); - vec_xst(v_y, 0, (ggml_fp16_t *)(y + i)); - } - for (; i + 3 < n; i += 4) { - float32x4_t v_x = vec_xl(0, (const float *)(x + i)); - float32x4_t v_zero = vec_splats(0.0f); - uint16x8_t v_yd = vec_round_from_fp32(v_x, v_zero, 0); - uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0); - vec_xst(v_y, 0, (ggml_fp16_t *)(y + i)); +#elif defined(__riscv_zvfh) + for (int vl; i < n; i += vl) { + vl = __riscv_vsetvl_e32m2(n - i); + vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl); + vfloat16m1_t vy = __riscv_vfncvt_f_f_w_f16m1(vx, vl); + __riscv_vse16_v_f16m1((_Float16 *)&y[i], vy, vl); } #endif for (; i < n; ++i) { @@ -3241,21 +3250,6 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) { __m128 y_vec = _mm_cvtph_ps(x_vec); _mm_storeu_ps(y + i, y_vec); } -#elif defined(__NNPA__) - for (; i + 7 < n; i += 8) { - uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)(x + i)); - uint16x8_t v_yd = vec_convert_from_fp16(v_x, 0); - float32x4_t v_yh = vec_extend_to_fp32_hi(v_yd, 0); - float32x4_t v_yl = vec_extend_to_fp32_lo(v_yd, 0); - vec_xst(v_yh, 0, (float *)(y + i + 0)); - vec_xst(v_yl, 0, (float *)(y + i + 4)); - } - for (; i + 3 < n; i += 4) { - uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)(x + i)); - uint16x8_t v_yd = vec_convert_from_fp16(v_x, 0); - float32x4_t v_yh = vec_extend_to_fp32_hi(v_yd, 0); - vec_xst(v_yh, 0, (float *)(y + i)); - } #endif for (; i < n; ++i) { @@ -3270,6 +3264,13 @@ void ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) { } } +void ggml_cpu_fp32_to_i32(const float * x, int32_t * y, int64_t n) { + int64_t i = 0; + for (; i < n; ++i) { + y[i] = x[i]; + } +} + void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) { int64_t i = 0; #if defined(__AVX2__) @@ -3459,14 +3460,6 @@ int ggml_cpu_has_vxe(void) { #endif } -int ggml_cpu_has_nnpa(void) { -#if defined(GGML_NNPA) - return 1; -#else - return 0; -#endif -} - int ggml_cpu_has_neon(void) { #if defined(__ARM_ARCH) && defined(__ARM_NEON) return 1; diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 8dacd36714b..81a314e4d68 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -190,6 +190,7 @@ static const struct ggml_backend_i ggml_backend_cpu_i = { /* .graph_compute = */ ggml_backend_cpu_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, }; static ggml_guid_t ggml_backend_cpu_guid(void) { @@ -348,8 +349,10 @@ static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * long pages = sysconf(_SC_PHYS_PAGES); long page_size = sysconf(_SC_PAGE_SIZE); *total = pages * page_size; + + // "free" system memory is ill-defined, for practical purposes assume that all of it is free: *free = *total; -#endif +#endif // _WIN32 GGML_UNUSED(dev); } @@ -576,9 +579,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r if (ggml_cpu_has_vxe()) { features.push_back({ "VXE", "1" }); } - if (ggml_cpu_has_nnpa()) { - features.push_back({ "NNPA", "1" }); - } if (ggml_cpu_has_wasm_simd()) { features.push_back({ "WASM_SIMD", "1" }); } diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index ddd29d002d1..7ba659124ca 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -14,6 +14,7 @@ #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h" #include "kai_lhs_quant_pack_qsi8d32p_f32.h" +#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h" #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" #include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" @@ -127,6 +128,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, }, + /* .gemm_lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, + }, /* SME GEMV */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, @@ -141,7 +148,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, }, - /* .lhs_info = */ { + /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon, /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, @@ -173,6 +180,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, }, + /* .gemm_lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, + /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, + }, /* SME GEMV */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, @@ -187,7 +200,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, }, - /* .lhs_info = */ { + /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, @@ -222,6 +235,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, }, + /* .gemm_lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + }, /* DOTPROD GEMV */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, @@ -236,7 +255,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, }, - /* .lhs_info = */ { + /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, @@ -270,6 +289,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, }, + /* .gemm_lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + }, /* i8mm GEMV */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, @@ -284,7 +309,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, }, - /* .lhs_info = */ { + /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, @@ -319,6 +344,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, }, + /* .gemm_lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + }, /* i8mm GEMV */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, @@ -333,7 +364,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, }, - /* .lhs_info = */ { + /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, @@ -367,6 +398,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, }, + /* .gemm_lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + }, /* DOTPROD GEMV */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, @@ -381,7 +418,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, }, - /* .lhs_info = */ { + /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h index bc8f33405d1..2ad6ad6fd0b 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.h +++ b/ggml/src/ggml-cpu/kleidiai/kernels.h @@ -84,8 +84,11 @@ struct rhs_packing_info { struct ggml_kleidiai_kernels { kernel_info gemm; + lhs_packing_info gemm_lhs_info; + kernel_info gemv; - lhs_packing_info lhs_info; + lhs_packing_info gemv_lhs_info; + rhs_packing_info rhs_info; cpu_feature required_cpu; diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index dff8fa244a1..8694ee15d3f 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -123,7 +123,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { } ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op); GGML_ASSERT(kernels); - kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; + bool is_gemv = op->src[1]->ne[1] == 1; + kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; + lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; size_t k = op->src[0]->ne[0]; size_t n = op->src[0]->ne[1]; @@ -134,9 +136,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t sr = kernel->get_sr(); if (kernels->rhs_type == GGML_TYPE_Q4_0) { - size = variant_call(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr); + size = variant_call(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr); } else if (kernels->rhs_type == GGML_TYPE_F16) { - size = variant_call(kernels->lhs_info.packed_size, m, k, mr, kr, sr) + + size = variant_call(lhs_info->packed_size, m, k, mr, kr, sr) + variant_call(kernels->rhs_info.packed_size, n, k) + k * n * sizeof(float) + n * sizeof(float); } else { @@ -152,7 +154,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { if (dst->src[0]->type == GGML_TYPE_Q4_0) { return compute_forward_q4_0(params, dst); } else if (dst->src[0]->type == GGML_TYPE_F16) { - return compute_forward_kv_cache(params, dst); + return compute_forward_fp16(params, dst); } } else if (dst->op == GGML_OP_GET_ROWS) { if (dst->src[0]->type == GGML_TYPE_Q4_0) { @@ -162,7 +164,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { return false; } - bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) { + bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) { static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT; const ggml_tensor * src0 = dst->src[0]; @@ -173,7 +175,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); GGML_ASSERT(kernels); - kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; + bool is_gemv = src1->ne[1] == 1; + kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; + lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; GGML_ASSERT(kernel); const int nth = params->nth; @@ -198,7 +202,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t kr = static_cast(kernel->get_kr()); const int64_t sr = static_cast(kernel->get_sr()); - const size_t lhs_packed_size = variant_call(kernels->lhs_info.packed_size, m, k, mr, kr, sr); + const size_t lhs_packed_size = variant_call(lhs_info->packed_size, m, k, mr, kr, sr); const size_t rhs_packed_size = variant_call(kernels->rhs_info.packed_size, n, k); const size_t kxn_size = k * n * sizeof(float); const size_t bias_size = n * sizeof(float); @@ -229,12 +233,12 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; const size_t lhs_offset = variant_call(kernels->gemm.get_lhs_offset, m_start, lhs_stride); - const size_t lhs_packed_offset = variant_call(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr); + const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, m_start, k, mr, kr, sr); const void * src_ptr = static_cast(lhs_batch) + lhs_offset; void * dst_ptr = static_cast(lhs_packed) + lhs_packed_offset; - variant_call(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr); + variant_call(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr); } } @@ -306,8 +310,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); GGML_ASSERT(kernels); - kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; - lhs_packing_info * lhs_info = &kernels->lhs_info; + bool is_gemv = src1->ne[1] == 1; + kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; + lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; GGML_ASSERT(kernel); @@ -510,9 +515,6 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) { - if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) { - return false; - } if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { return false; } @@ -529,13 +531,8 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } - else if (ggml_kleidiai_select_kernels(ctx.features, op) && - op->src[0]->op == GGML_OP_VIEW && - (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) && - op->src[1]->ne[1] > 1) { - if ((op->src[0]->nb[0] != 2) || - (op->src[1]->nb[0] != 4) || - (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || + else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) { + if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { return nullptr; } diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 2be54c31b5f..2c4ad9d58b9 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -2169,94 +2169,117 @@ class tinyBLAS_Q0_PPC { class tinyBLAS_PPC { public: tinyBLAS_PPC(int64_t k, - const float *A, int64_t lda, - const float *B, int64_t ldb, - float *C, int64_t ldc, + const float * A, int64_t lda, + const float * B, int64_t ldb, + float * C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } void matmul(int64_t m, int64_t n) { - mnpack(0, m, 0, n); + int64_t mc = 256; int64_t nc = 256; int64_t kc = 256; + if (m % mc == 0 && n % nc == 0 && k % kc == 0) { + matmul_tiled(m, n, mc, nc, kc); + } else { + mnpack(0, m, 0, n); + } } private: - void (tinyBLAS_PPC::*kernel)(int64_t, int64_t); + inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) { + vec_t vec_C[4]; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int I = 0; I < 4; I++) { + for (int J = 0; J < 4; J++) { + *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J); + } + } + } - inline void vector_permute_store_4(vector float *src, float *vecOffset) { - vector float t1, t2, t3, t4, t5, t6, t7, t8; - t1 = vec_mergeh(src[0], src[1]); - t2 = vec_mergeh(src[2], src[3]); - t3 = vec_mergel(src[0], src[1]); - t4 = vec_mergel(src[2], src[3]); + inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) { + vec_t vec_C[4]; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int I = 0; I < 4; I++) { + for (int J = 0; J < 4; J++) { + float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I); + *c_ptr += *((float *)&vec_C[I]+J); + } + } + } - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t1, t2, 3); - t7 = vec_xxpermdi(t3, t4, 0); - t8 = vec_xxpermdi(t3, t4, 3); + inline void vector_permute_store_4(vector float * src, float * vecOffset) { + vector float t1, t2, t3, t4, t5, t6, t7, t8; + t1 = vec_mergeh(src[0], src[1]); + t2 = vec_mergeh(src[2], src[3]); + t3 = vec_mergel(src[0], src[1]); + t4 = vec_mergel(src[2], src[3]); - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset + 4); - vec_xst(t7, 0, vecOffset + 8); - vec_xst(t8, 0, vecOffset + 12); - } + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t1, t2, 3); + t7 = vec_xxpermdi(t3, t4, 0); + t8 = vec_xxpermdi(t3, t4, 3); - inline void vector_permute_store_8(vector float *src, float *vecOffset) { - vector float t1, t2, t3, t4, t5, t6, t7, t8; - t1 = vec_mergeh(src[0], src[1]); - t2 = vec_mergeh(src[2], src[3]); - t3 = vec_mergeh(src[4], src[5]); - t4 = vec_mergeh(src[6], src[7]); + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 4); + vec_xst(t7, 0, vecOffset + 8); + vec_xst(t8, 0, vecOffset + 12); + } - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); + inline void vector_permute_store_8(vector float * src, float * vecOffset) { + vector float t1, t2, t3, t4, t5, t6, t7, t8; + t1 = vec_mergeh(src[0], src[1]); + t2 = vec_mergeh(src[2], src[3]); + t3 = vec_mergeh(src[4], src[5]); + t4 = vec_mergeh(src[6], src[7]); - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset + 4); - vec_xst(t7, 0, vecOffset + 8); - vec_xst(t8, 0, vecOffset + 12); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); - t1 = vec_mergel(src[0], src[1]); - t2 = vec_mergel(src[2], src[3]); - t3 = vec_mergel(src[4], src[5]); - t4 = vec_mergel(src[6], src[7]); + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 4); + vec_xst(t7, 0, vecOffset + 8); + vec_xst(t8, 0, vecOffset + 12); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); + t1 = vec_mergel(src[0], src[1]); + t2 = vec_mergel(src[2], src[3]); + t3 = vec_mergel(src[4], src[5]); + t4 = vec_mergel(src[6], src[7]); - vec_xst(t5, 0, vecOffset + 16); - vec_xst(t6, 0, vecOffset + 20); - vec_xst(t7, 0, vecOffset + 24); - vec_xst(t8, 0, vecOffset + 28); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset + 16); + vec_xst(t6, 0, vecOffset + 20); + vec_xst(t7, 0, vecOffset + 24); + vec_xst(t8, 0, vecOffset + 28); } - void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) { + void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) { int64_t i, j; float * aoffsets[8]; - float *aoffset = NULL, *boffset = NULL; + float * aoffset = NULL, * boffset = NULL; __vector_pair arr[8]; vector float c[8][2] = {0}; vector float c1[8] = {0}; vector float c2[8] = {0}; - aoffset = const_cast(a); + aoffset = const_cast(a); boffset = vec; j = (rows >> 3); if (j > 0) { - do { aoffsets[0] = aoffset; - for (int it = 1; it< 8; it++) + for (int it = 1; it < 8; it++) aoffsets[it] = aoffsets[it-1] + lda; aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { do { - for (int it = 0; it< 8; it++) { + for (int it = 0; it < 8; it++) { arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]); __builtin_vsx_disassemble_pair(c[it], &arr[it]); c1[it] = c[it][0]; @@ -2264,11 +2287,14 @@ class tinyBLAS_PPC { } vector_permute_store_8(c1, boffset); - vector_permute_store_8(c2, boffset+32); - for (int it = 0; it < 4; it++) - aoffsets[it] = aoffsets[it] + 8*lda; + vector_permute_store_8(c2, boffset + 32); boffset += 64; i--; + if (i > 0) { + for (int it = 0; it < 8; it++) { + aoffsets[it] = aoffsets[it] + 8; + } + } } while(i > 0); } if (cols & 4) { @@ -2295,9 +2321,9 @@ class tinyBLAS_PPC { c2[it] = c[it][1]; } vector_permute_store_4(c1, boffset); - vector_permute_store_4(c2, boffset+16); + vector_permute_store_4(c2, boffset + 16); for (int it = 0; it < 4; it++) - aoffsets[it] += 8*lda; + aoffsets[it] += 8 * lda; boffset += 32; i--; } while(i > 0); @@ -2325,15 +2351,15 @@ class tinyBLAS_PPC { vec_t vec_A[4], vec_B[4], vec_C[4]; acc_t acc_0; __builtin_mma_xxsetaccz(&acc_0); - for (int l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); + for (int l = 0; l < k; l += 4) { + packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A); + packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B); __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]); } - SAVE_ACC(&acc_0, ii, jj); + save_acc(&acc_0, ii, jj); } void KERNEL_4x8(int64_t ii, int64_t jj) { @@ -2341,9 +2367,9 @@ class tinyBLAS_PPC { acc_t acc_0, acc_1; __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); - for (int64_t l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B); + for (int64_t l = 0; l < k; l += 4) { + packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A); + packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B); __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]); @@ -2353,8 +2379,8 @@ class tinyBLAS_PPC { __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]); __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]); } - SAVE_ACC(&acc_0, ii, jj); - SAVE_ACC(&acc_1, ii, jj+4); + save_acc(&acc_0, ii, jj); + save_acc(&acc_1, ii, jj + 4); } void KERNEL_8x4(int64_t ii, int64_t jj) { @@ -2362,9 +2388,9 @@ class tinyBLAS_PPC { acc_t acc_0, acc_1; __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); - for (int64_t l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); + for (int64_t l = 0; l < k; l += 4) { + packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A); + packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]); @@ -2374,8 +2400,8 @@ class tinyBLAS_PPC { __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]); } - SAVE_ACC(&acc_0, ii, jj); - SAVE_ACC(&acc_1, ii+4, jj); + save_acc(&acc_0, ii, jj); + save_acc(&acc_1, ii + 4, jj); } void KERNEL_8x8(int64_t ii, int64_t jj) { @@ -2386,19 +2412,96 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_2); __builtin_mma_xxsetaccz(&acc_3); for (int l = 0; l < k; l+=8) { - packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B); + packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A); + packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B); for(int x = 0; x < 16; x+=2) { __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]); - __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]); - __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]); - __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]); + __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]); + __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]); + __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]); + } + } + save_acc(&acc_0, ii, jj); + save_acc(&acc_1, ii, jj + 4); + save_acc(&acc_2, ii + 4, jj); + save_acc(&acc_3, ii + 4, jj + 4); + } + + inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) { + for (int x = 0; x < 16; x += 2) { + __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]); + __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]); + __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]); + __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]); + __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]); + __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]); + __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]); + __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]); + } + } + + void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) { + for (int64_t i = 0; i < mc; i += 16) { + int A_base_addr = (mc / 8) * (i / 8) * 16; + for (int64_t j = 0; j < nc; j += 8) { + int B_base_addr = (nc / 8) * (j / 8) * 16; + acc_t acc[8]; + vec_t A0_block[16]; vec_t A1_block[16]; + for (int x = 0; x < 8; x++) + __builtin_mma_xxsetaccz(&acc[x]); + for (int64_t l = 0; l < kc; l += 8) { + int A0_block_idx = A_base_addr + (l / 8) * 16; + int A1_block_idx = A0_block_idx + (mc / 8) * 16; + int B_block_idx = B_base_addr + (l / 8) * 16; + vec_t* A0_block = &vec_A[A0_block_idx]; + vec_t* A1_block = &vec_A[A1_block_idx]; + vec_t* B_block = &vec_B[B_block_idx]; + MMA_16x8(A0_block, A1_block, B_block, acc); + } + if (kk == 0) { + save_acc(&acc[0], ii + i, jj + j); + save_acc(&acc[1], ii + i, jj + j + 4); + save_acc(&acc[2], ii + i + 4, jj + j); + save_acc(&acc[3], ii + i + 4, jj + j + 4); + save_acc(&acc[4], ii + i + 8, jj + j); + save_acc(&acc[5], ii + i + 8, jj + j + 4); + save_acc(&acc[6], ii + i + 12, jj + j); + save_acc(&acc[7], ii + i + 12, jj + j + 4); + } else { + add_save_acc(&acc[0], ii + i, jj + j); + add_save_acc(&acc[1], ii + i, jj + j + 4); + add_save_acc(&acc[2], ii + i + 4, jj + j); + add_save_acc(&acc[3], ii + i + 4, jj + j + 4); + add_save_acc(&acc[4], ii + i + 8, jj + j); + add_save_acc(&acc[5], ii + i + 8, jj + j + 4); + add_save_acc(&acc[6], ii + i + 12, jj + j); + add_save_acc(&acc[7], ii + i + 12, jj + j + 4); + } + } + } + } + + void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) { + int64_t ytiles = m / mc; + int64_t xtiles = n / nc; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) { + end = tiles; + } + for (int64_t job = start; job < end; ++job) { + int64_t ii = (job / xtiles) * mc; + int64_t jj = (job % xtiles) * nc; + for (int64_t kk = 0; kk < k; kk += kc) { + vec_t A_pack[kc * mc / 4]; + vec_t B_pack[kc * nc / 4]; + packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack); + packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack); + KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk); } } - SAVE_ACC(&acc_0, ii, jj); - SAVE_ACC(&acc_1, ii, jj+4); - SAVE_ACC(&acc_2, ii+4, jj); - SAVE_ACC(&acc_3, ii+4, jj+4); } void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { @@ -2406,35 +2509,35 @@ class tinyBLAS_PPC { int n_rem = MIN(n - n0, 8); int mc = 0, nc = 0; if (m_rem >= 8 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8, 8>(m0, m, n0, n); + mc = 8; + nc = 8; + gemm<8, 8>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 8) { - mc = 4; - nc = 8; - gemm<4, 8>(m0, m, n0, n); + mc = 4; + nc = 8; + gemm<4, 8>(m0, m, n0, n); } else if (m_rem >= 8 && n_rem >= 4) { - mc = 8; - nc = 4; - gemm<8, 4>(m0, m, n0, n); + mc = 8; + nc = 4; + gemm<8, 4>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 4) { - mc = 4; - nc = 4; - gemm<4, 4>(m0, m, n0, n); + mc = 4; + nc = 4; + gemm<4, 4>(m0, m, n0, n); } else { mc = (m_rem >= 4) ? 4 : m_rem; nc = (n_rem >= 4) ? 4 : n_rem; if (mc == 0 || nc == 0) - return; + return; gemm_small(m0, m, n0, n, mc, nc); } int64_t mp = m0 + ((m - m0) / mc) * mc; int64_t np = n0 + ((n - n0) / nc) * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); - } + } - void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2449,30 +2552,30 @@ class tinyBLAS_PPC { vec_t vec_C[4]; acc_t acc_0; __builtin_mma_xxsetaccz(&acc_0); - vec_t vec_A[4] {0}, vec_B[4] = {0}; - for (int l=0; l(A+(ii)*lda+l); - packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); + float * a = const_cast(A + (ii) * lda + l); + packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B); vec_A[0] = (vec_t)vec_xl(0,a); - vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1)); - vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2)); - vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3)); + vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1)); + vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2)); + vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3)); } else if (RN == 1) { - packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); - float* b = const_cast(B+(jj)*ldb+l); + packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A); + float * b = const_cast(B + (jj) * ldb + l); vec_B[0] = (vec_t)vec_xl(0,b); - vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1)); - vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2)); - vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3)); + vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1)); + vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2)); + vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3)); } else { - packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); + packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A); + packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B); } __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); @@ -2482,12 +2585,27 @@ class tinyBLAS_PPC { __builtin_mma_disassemble_acc(vec_C, &acc_0); for (int I = 0; I < RM; I++) { for (int J = 0; J < RN; J++) { - *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); + *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J); } } } } + template + inline void kernel(int64_t ii, int64_t jj) { + if constexpr(RM == 4 && RN == 4) { + KERNEL_4x4(ii, jj); + } else if constexpr(RM == 4 && RN == 8) { + KERNEL_4x8(ii, jj); + } else if constexpr(RM == 8 && RN == 4) { + KERNEL_8x4(ii, jj); + } else if constexpr(RM == 8 && RN == 8) { + KERNEL_8x8(ii, jj); + } else { + static_assert(false, "RN/RM values not supported"); + } + } + template NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t ytiles = (m - m0) / RM; @@ -2496,27 +2614,18 @@ class tinyBLAS_PPC { int64_t duty = (tiles + nth - 1) / nth; int64_t start = duty * ith; int64_t end = start + duty; - if (RM == 4 && RN == 4) { - kernel = &tinyBLAS_PPC::KERNEL_4x4; - } else if (RM == 4 && RN == 8) { - kernel = &tinyBLAS_PPC::KERNEL_4x8; - } else if (RM == 8 && RN == 4) { - kernel = &tinyBLAS_PPC::KERNEL_8x4; - } else if (RM == 8 && RN == 8) { - kernel = &tinyBLAS_PPC::KERNEL_8x8; - } if (end > tiles) end = tiles; for (int64_t job = start; job < end; ++job) { int64_t ii = m0 + job / xtiles * RM; int64_t jj = n0 + job % xtiles * RN; - (this->*kernel)(ii, jj); + kernel(ii, jj); } } - const float *const A; - const float *const B; - float *C; + const float * const A; + const float * const B; + float * C; const int64_t k; const int64_t lda; const int64_t ldb; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index b72a2556a5f..763ab099e31 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -41,13 +41,15 @@ static void ggml_compute_forward_dup_same_cont( } } -static void ggml_compute_forward_dup_f16( +template +static void ggml_compute_forward_dup_flt( const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type)); GGML_TENSOR_UNARY_OP_LOCALS @@ -62,6 +64,7 @@ static void ggml_compute_forward_dup_f16( const int ir0 = dr * ith; const int ir1 = MIN(ir0 + dr, nr); + // case: type & row size equal if (src0->type == dst->type && ne00 == ne0 && nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { @@ -80,11 +83,11 @@ static void ggml_compute_forward_dup_f16( return; } - // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy - + // case: dst tensor is contiguous if (ggml_is_contiguous(dst)) { - if (nb00 == sizeof(ggml_fp16_t)) { - if (dst->type == GGML_TYPE_F16) { + if (nb00 == sizeof(src_t)) { + if constexpr (std::is_same_v) { + // same type size_t id = 0; const size_t rs = ne00 * nb00; char * dst_ptr = (char *) dst->data; @@ -100,750 +103,58 @@ static void ggml_compute_forward_dup_f16( id += rs * (ne01 - ir1); } } - } else if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - for (int i00 = 0; i00 < ne00; i00++) { - dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (ggml_get_type_traits_cpu(dst->type)->from_float) { - ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float; - float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; - - size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - for (int i00 = 0; i00 < ne00; i00++) { - src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]); - } - - quantize_row_q(src0_f32, dst_ptr + id, ne00); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } - return; - } - - // dst counters - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - if (dst->type == GGML_TYPE_F16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t)); - - if (++i10 == ne00) { - i10 = 0; - if (++i11 == ne01) { - i11 = 0; - if (++i12 == ne02) { - i12 = 0; - if (++i13 == ne03) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == GGML_TYPE_F32) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } -} - -static void ggml_compute_forward_dup_bf16( - const ggml_compute_params * params, - ggml_tensor * dst) { - - const ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - - GGML_TENSOR_UNARY_OP_LOCALS - - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - // parallelize by rows - const int nr = ne01; - // number of rows per thread - const int dr = (nr + nth - 1) / nth; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (src0->type == dst->type && - ne00 == ne0 && - nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { - // copy by rows - const size_t rs = ne00*nb00; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ir0; i01 < ir1; i01++) { - memcpy( - ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - rs); - } - } - } - return; - } - - // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy - - if (ggml_is_contiguous(dst)) { - if (nb00 == sizeof(ggml_bf16_t)) { - if (dst->type == GGML_TYPE_BF16) { - size_t id = 0; - const size_t rs = ne00 * nb00; - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, rs); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - for (int i00 = 0; i00 < ne00; i00++) { - dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00])); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F32) { + // casting between non-quantized types size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - for (int i00 = 0; i00 < ne00; i00++) { - dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (ggml_get_type_traits_cpu(dst->type)->from_float) { - ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float; - float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; - - size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - for (int i00 = 0; i00 < ne00; i00++) { - src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]); - } - - quantize_row_q(src0_f32, dst_ptr + id, ne00); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_BF16) { - size_t id = 0; - ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr)); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } - return; - } - - // dst counters - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - if (dst->type == GGML_TYPE_BF16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t)); - - if (++i10 == ne00) { - i10 = 0; - if (++i11 == ne01) { - i11 = 0; - if (++i12 == ne02) { - i12 = 0; - if (++i13 == ne03) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == GGML_TYPE_F16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr)); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == GGML_TYPE_F32) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } -} - -static void ggml_compute_forward_dup_f32( - const ggml_compute_params * params, - ggml_tensor * dst) { - - const ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - - GGML_TENSOR_UNARY_OP_LOCALS - - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - // parallelize by rows - const int nr = ne01; - // number of rows per thread - const int dr = (nr + nth - 1) / nth; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (src0->type == dst->type && - ne00 == ne0 && - nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { - // copy by rows - const size_t rs = ne00*nb00; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ir0; i01 < ir1; i01++) { - memcpy( - ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - rs); - } - } - } - return; - } - - if (ggml_is_contiguous(dst)) { - // TODO: simplify - if (nb00 == sizeof(float)) { - if (ggml_get_type_traits_cpu(dst->type)->from_float) { - ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float; - - size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - from_float(src0_ptr, dst_ptr + id, ne00); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_BF16) { - size_t id = 0; - ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } - - return; - } - - // dst counters - - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - if (dst->type == GGML_TYPE_F32) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, sizeof(float)); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } + dst_t * dst_ptr = (dst_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + for (int i00 = 0; i00 < ne00; i00++) { + float tmp = type_conversion_table::to_f32(src0_ptr[i00]); + dst_ptr[id] = type_conversion_table::from_f32(tmp); + id++; } } + id += ne00 * (ne01 - ir1); } } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + size_t id = 0; + dst_t * dst_ptr = (dst_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + float tmp = type_conversion_table::to_f32(*src0_ptr); + dst_ptr[id] = type_conversion_table::from_f32(tmp); + id++; } } + id += ne00 * (ne01 - ir1); } } } - } else if (dst->type == GGML_TYPE_F16) { + return; + } + + // dst counters + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if constexpr (std::is_same_v) { for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { i10 += ne00 * ir0; @@ -864,15 +175,15 @@ static void ggml_compute_forward_dup_f32( const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr); + memcpy(dst_ptr, src0_ptr, sizeof(dst_t)); - if (++i10 == ne0) { + if (++i10 == ne00) { i10 = 0; - if (++i11 == ne1) { + if (++i11 == ne01) { i11 = 0; - if (++i12 == ne2) { + if (++i12 == ne02) { i12 = 0; - if (++i13 == ne3) { + if (++i13 == ne03) { i13 = 0; } } @@ -895,7 +206,8 @@ static void ggml_compute_forward_dup_f32( } } } - } else if (dst->type == GGML_TYPE_BF16) { + + } else { for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { i10 += ne00 * ir0; @@ -916,7 +228,8 @@ static void ggml_compute_forward_dup_f32( const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr); + float tmp = type_conversion_table::to_f32(*(const src_t *) src0_ptr); + *(dst_t *) dst_ptr = type_conversion_table::from_f32(tmp); if (++i10 == ne0) { i10 = 0; @@ -947,8 +260,63 @@ static void ggml_compute_forward_dup_f32( } } } + } +} + + +template +static void ggml_compute_forward_dup_to_q( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(!ggml_is_quantized(src0->type)); + + GGML_TENSOR_UNARY_OP_LOCALS + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (ggml_is_contiguous(dst) && + nb00 == sizeof(src_t) && + ggml_get_type_traits_cpu(dst->type)->from_float) { + // casting non-quantized types --> intermediate f32 --> quantized + ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float; + float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + size_t id = 0; + size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + for (int i00 = 0; i00 < ne00; i00++) { + src0_f32[i00] = type_conversion_table::to_f32(src0_ptr[i00]); + } + + quantize_row_q(src0_f32, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } } else { - GGML_ABORT("fatal error"); // TODO: implement + // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type)); + GGML_ABORT("not implemented"); } } @@ -1102,7 +470,7 @@ static void ggml_compute_forward_dup_bytes( } } -static void ggml_compute_forward_dup_q( +static void ggml_compute_forward_dup_from_q( const ggml_compute_params * params, ggml_tensor * dst) { @@ -1167,20 +535,35 @@ void ggml_compute_forward_dup( switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_dup_f16(params, dst); + /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt(params, dst); + else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt(params, dst); + else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt(params, dst); + else ggml_compute_forward_dup_to_q(params, dst); } break; case GGML_TYPE_BF16: { - ggml_compute_forward_dup_bf16(params, dst); + /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt(params, dst); + else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt(params, dst); + else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt(params, dst); + else ggml_compute_forward_dup_to_q(params, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_dup_f32(params, dst); + /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt(params, dst); + else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt(params, dst); + else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt(params, dst); + else if (dst->type == GGML_TYPE_I32) ggml_compute_forward_dup_flt(params, dst); + else ggml_compute_forward_dup_to_q(params, dst); + } break; + case GGML_TYPE_I32: + { + if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt(params, dst); + else GGML_ABORT("not implemented"); } break; default: { if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) { - ggml_compute_forward_dup_q(params, dst); + ggml_compute_forward_dup_from_q(params, dst); break; } GGML_ABORT("fatal error"); @@ -7027,6 +6410,209 @@ void ggml_compute_forward_im2col_back_f32( } } + +// ggml_compute_forward_im2col_3d_f16 +// src0: kernel [OC*IC, KD, KH, KW] +// src1: image [N*IC, ID, IH, IW] +// dst: result [N*OD, OH, OW, IC * KD * KH * KW] +static void ggml_compute_forward_im2col_3d_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t OC = ne03 / IC; + GGML_UNUSED(OC); + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + const int64_t OH_OW = OH*OW; + const int64_t KD_KH_KW = KD*KH*KW; + const int64_t KH_KW = KH*KW; + const int64_t IC_KD_KH_KW = IC*KD*KH*KW; + + GGML_ASSERT(nb10 == sizeof(float)); + + // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t iod = 0; iod < OD; iod++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + + // micro kernel + ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW] + const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW] + + for (int64_t ikd = 0; ikd < KD; ikd++) { + for (int64_t ikh = 0; ikh < KH; ikh++) { + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + const int64_t iid = iod*s2 + ikd*d2 - p2; + + if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0; + } else { + const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW] + dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s); + } + } + } + } + } + } + } + } + } + } +} + +// ggml_compute_forward_im2col_3d_f32 +// src0: kernel [OC*IC, KD, KH, KW] +// src1: image [N*IC, ID, IH, IW] +// dst: result [N*OD, OH, OW, IC * KD * KH * KW] +static void ggml_compute_forward_im2col_3d_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t OC = ne03 / IC; + GGML_UNUSED(OC); + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + + const int64_t OH_OW = OH*OW; + const int64_t KD_KH_KW = KD*KH*KW; + const int64_t KH_KW = KH*KW; + const int64_t IC_KD_KH_KW = IC*KD*KH*KW; + + GGML_ASSERT(nb10 == sizeof(float)); + + // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] + { + float * const wdata = (float *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t iod = 0; iod < OD; iod++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + + // micro kernel + float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW] + const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW] + + for (int64_t ikd = 0; ikd < KD; ikd++) { + for (int64_t ikh = 0; ikh < KH; ikh++) { + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + const int64_t iid = iod*s2 + ikd*d2 - p2; + + if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0; + } else { + const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW] + dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s; + } + } + } + } + } + } + } + } + } + } +} + + +void ggml_compute_forward_im2col_3d( + const ggml_compute_params * params, + ggml_tensor * dst) { + switch (dst->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_im2col_3d_f16(params, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_im2col_3d_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k, void * a, void * b, float * c) { const ggml_type_traits * traits = ggml_get_type_traits(type); @@ -7207,6 +6793,148 @@ void ggml_compute_forward_conv_2d( ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type); } +// ggml_compute_forward_conv_3d + +static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params, + const ggml_tensor * kernel, + const ggml_tensor * src, + ggml_tensor * dst, + ggml_type kernel_type) { + + GGML_ASSERT(ggml_is_contiguous(kernel)); + GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32); + GGML_ASSERT(kernel->type == kernel_type); + + const ggml_type_traits * traits = ggml_get_type_traits(kernel_type); + + const int32_t s0 = dst->op_params[0]; + const int32_t s1 = dst->op_params[1]; + const int32_t s2 = dst->op_params[2]; + const int32_t p0 = dst->op_params[3]; + const int32_t p1 = dst->op_params[4]; + const int32_t p2 = dst->op_params[5]; + const int32_t d0 = dst->op_params[6]; + const int32_t d1 = dst->op_params[7]; + const int32_t d2 = dst->op_params[8]; + const int32_t c = dst->op_params[9]; + const int32_t n = dst->op_params[10]; + const int32_t oc = dst->op_params[11]; + + const int64_t src_w = src->ne[0]; + const int64_t src_h = src->ne[1]; + const int64_t src_d = src->ne[2]; + const int64_t knl_w = kernel->ne[0]; + const int64_t knl_h = kernel->ne[1]; + const int64_t knl_d = kernel->ne[2]; + const int64_t dst_w = dst->ne[0]; + const int64_t dst_h = dst->ne[1]; + const int64_t dst_d = dst->ne[2]; + + const float * src_data = (float *) src->data; + void * knl_data = kernel->data; + float * dst_data = (float *) dst->data; + + const int64_t knl_n_per_channel = knl_w * knl_h * knl_d; + const int64_t knl_n_total = knl_n_per_channel * c; + const int64_t patch_total = n * dst_w * dst_h * dst_d; + + const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float); + const int64_t batch_size = params->wsize / space_per_patch; + const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size; + const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch; + + GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1); + + void * tmp = params->wdata; + + for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) { + const int64_t patch_start_batch = batch_i * patches_per_batch; + const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total); + const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch; + + const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth; + const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread; + const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch); + + for (int64_t p = patch_start; p < patch_end; ++p) { + const int64_t p_in_batch = p % (dst_w * dst_h * dst_d); + const int64_t p_in_depth = p_in_batch % (dst_w * dst_h); + const int64_t batch_idx = p / (dst_w * dst_h * dst_d); + const int64_t dst_z = p_in_batch / (dst_w * dst_h); + const int64_t dst_y = p_in_depth / dst_w; + const int64_t dst_x = p_in_depth % dst_w; + + char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size; + + for (int64_t ic = 0; ic < c; ++ic) { + for (int64_t kz = 0; kz < knl_d; ++kz) { + for (int64_t ky = 0; ky < knl_h; ++ky) { + for (int64_t kx = 0; kx < knl_w; ++kx) { + const int64_t sz = dst_z * s2 + kz * d2 - p2; + const int64_t sy = dst_y * s1 + ky * d1 - p1; + const int64_t sx = dst_x * s0 + kx * d0 - p0; + + int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx; + + float src_val; + if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) { + src_val = 0.0f; + } else { + const int64_t cn_idx = batch_idx * c + ic; + const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]); + src_val = *src_ptr; + } + + char * element_ptr = dst_row + dst_idx * traits->type_size; + if (kernel_type == GGML_TYPE_F32) { + *(float *)element_ptr = src_val; + } else if (kernel_type == GGML_TYPE_F16) { + *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val); + } + } + } + } + } + } + + ggml_barrier(params->threadpool); + + float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size); + ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output); + + ggml_barrier(params->threadpool); + + const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth; + const int64_t permute_start = params->ith * permute_per_thread; + const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch); + + for (int64_t i = permute_start; i < permute_end; ++i) { + const int64_t p = patch_start_batch + i; + const int64_t p_in_batch = p % (dst_w * dst_h * dst_d); + const int64_t p_in_depth = p_in_batch % (dst_w * dst_h); + const int64_t batch_idx = p / (dst_w * dst_h * dst_d); + const int64_t dst_z = p_in_batch / (dst_w * dst_h); + const int64_t dst_y = p_in_depth / dst_w; + const int64_t dst_x = p_in_depth % dst_w; + + for (int64_t ioc = 0; ioc < oc; ++ioc) { + const float value = gemm_output[i * oc + ioc]; + const int64_t ocn_idx = batch_idx * oc + ioc; + float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]); + *dst_ptr = value; + } + } + } +} + +void ggml_compute_forward_conv_3d( + const ggml_compute_params * params, + ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type); +} + // ggml_compute_forward_conv_transpose_2d void ggml_compute_forward_conv_transpose_2d( @@ -7872,6 +7600,15 @@ static void ggml_compute_forward_pad_f32( GGML_TENSOR_UNARY_OP_LOCALS float * dst_ptr = (float *) dst->data; + const int32_t lp0 = ggml_get_op_params_i32(dst, 0); + const int32_t rp0 = ggml_get_op_params_i32(dst, 1); + const int32_t lp1 = ggml_get_op_params_i32(dst, 2); + const int32_t rp1 = ggml_get_op_params_i32(dst, 3); + const int32_t lp2 = ggml_get_op_params_i32(dst, 4); + const int32_t rp2 = ggml_get_op_params_i32(dst, 5); + const int32_t lp3 = ggml_get_op_params_i32(dst, 6); + const int32_t rp3 = ggml_get_op_params_i32(dst, 7); + // TODO: optimize @@ -7880,10 +7617,12 @@ static void ggml_compute_forward_pad_f32( for (int64_t i0 = 0; i0 < ne0; ++i0) { for (int64_t i3 = 0; i3 < ne3; ++i3) { const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; - - const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + if ((i0 >= lp0 && i0 < ne0 - rp0) \ + && (i1 >= lp1 && i1 < ne1 - rp1) \ + && (i2 >= lp2 && i2 < ne2 - rp2) \ + && (i3 >= lp3 && i3 < ne3 - rp3)) { + const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00; + const float * src_ptr = (const float *)((char *) src0->data + src_idx); dst_ptr[dst_idx] = *src_ptr; } else { dst_ptr[dst_idx] = 0; @@ -8082,7 +7821,7 @@ static void ggml_compute_forward_timestep_embedding_f32( embed_data[j + half] = sinf(arg); } if (dim % 2 != 0 && ith == 0) { - embed_data[dim] = 0.f; + embed_data[2 * half] = 0.f; } } } @@ -8861,8 +8600,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); - // allows optimizing the modulo since n_group should be a power of 2 - GGML_ASSERT((ng & -ng) == ng); + GGML_ASSERT(nh % ng == 0); // heads per thread const int dh = (nh + nth - 1)/nth; @@ -8893,6 +8631,7 @@ static void ggml_compute_forward_ssm_scan_f32( // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; const float dA = expf(dt_soft_plus * A[h]); + const int g = h / (nh / ng); // repeat_interleave // dim for (int i1 = 0; i1 < nr; ++i1) { @@ -8915,8 +8654,8 @@ static void ggml_compute_forward_ssm_scan_f32( // TODO: maybe unroll more? for (int j = 0; j < 1; j++) { GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc); - GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc); - GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc); + GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc); + GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc); t0 = GGML_F32_VEC_MUL(t0, adA); t1 = GGML_F32_VEC_MUL(t1, axdt); @@ -8930,6 +8669,9 @@ static void ggml_compute_forward_ssm_scan_f32( } sumf = GGML_F32xt_REDUCE_ONE(sum); + #elif defined(__riscv_v_intrinsic) + // todo: RVV implementation + const int np = 0; #else const int np = (nc & ~(GGML_F32_STEP - 1)); @@ -8945,8 +8687,8 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i = 0; i < np; i += GGML_F32_STEP) { for (int j = 0; j < GGML_F32_ARR; j++) { ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc); - ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); - az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc); + az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc); ax[j] = GGML_F32_VEC_MUL(ax[j], adA); ay[j] = GGML_F32_VEC_MUL(ay[j], axdt); @@ -8968,7 +8710,7 @@ static void ggml_compute_forward_ssm_scan_f32( // d_state for (int i0 = np; i0 < nc; ++i0) { const int i = i0 + ii*nc; - const int ig = i0 + (h & (ng - 1))*nc; + const int ig = i0 + g*nc; // state = prev_state * dA + dB * x const float state = (s0[i] * dA) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) @@ -8985,6 +8727,7 @@ static void ggml_compute_forward_ssm_scan_f32( for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const int g = h / (nh / ng); // repeat_interleave // dim for (int i1 = 0; i1 < nr; ++i1) { @@ -8999,8 +8742,8 @@ static void ggml_compute_forward_ssm_scan_f32( // TODO: what happens when (d_state % svcntw()) != 0? for (int64_t k = 0; k < nc; k += svcntw()) { svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]); - svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]); - svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]); + svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]); + svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]); svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]); svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA); @@ -9020,7 +8763,7 @@ static void ggml_compute_forward_ssm_scan_f32( // d_state for (int i0 = 0; i0 < nc; ++i0) { const int i = i0 + ii*nc; - const int ig = i0 + (h & (ng - 1))*nc; + const int ig = i0 + g*nc; // state = prev_state * dA + dB * x const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) @@ -9881,8 +9624,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32( int64_t h_stride_2d = head_size * head_size; #if defined(GGML_SIMD) - #if defined(__ARM_FEATURE_SVE) - // scalar Route to scalar implementation //TODO: Write SVE code + #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic) + // scalar Route to scalar implementation //TODO: Write SVE code and RVV code for (int64_t t = 0; t < T; t++) { int64_t t_offset = t * t_stride; int64_t state_offset = head_size * C * (t / (T / n_seqs)); diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 82ea79eaa51..9824a03b458 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -69,7 +69,9 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index b4ad68c9fd6..a84ba75c20b 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -18,6 +18,10 @@ #include #endif +#if defined(__riscv_v_intrinsic) +#include +#endif + #ifdef __cplusplus extern "C" { #endif @@ -94,24 +98,15 @@ extern "C" { } #elif defined(__riscv) && defined(__riscv_zfhmin) static inline float riscv_compute_fp16_to_fp32(ggml_fp16_t h) { - float f; - __asm__( - "fmv.h.x %[f], %[h]\n\t" - "fcvt.s.h %[f], %[f]" - : [f] "=&f" (f) - : [h] "r" (h) - ); - return f; + _Float16 hf; + memcpy(&hf, &h, sizeof(ggml_fp16_t)); + return hf; } static inline ggml_fp16_t riscv_compute_fp32_to_fp16(float f) { ggml_fp16_t res; - __asm__( - "fcvt.h.s %[f], %[f]\n\t" - "fmv.x.h %[h], %[f]" - : [h] "=&r" (res) - : [f] "f" (f) - ); + _Float16 hf = (_Float16)f; + memcpy(&res, &hf, sizeof(ggml_fp16_t)); return res; } @@ -119,26 +114,6 @@ extern "C" { #define GGML_CPU_COMPUTE_FP32_TO_FP16(x) riscv_compute_fp32_to_fp16(x) #define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x) #define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x) -#elif defined(__NNPA__) - #define GGML_CPU_COMPUTE_FP16_TO_FP32(x) nnpa_compute_fp16_to_fp32(x) - #define GGML_CPU_COMPUTE_FP32_TO_FP16(x) nnpa_compute_fp32_to_fp16(x) - - #define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x) - #define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x) - - static inline float nnpa_compute_fp16_to_fp32(ggml_fp16_t h) { - uint16x8_t v_h = vec_splats(h); - uint16x8_t v_hd = vec_convert_from_fp16(v_h, 0); - return vec_extend_to_fp32_hi(v_hd, 0)[0]; - } - - static inline ggml_fp16_t nnpa_compute_fp32_to_fp16(float f) { - float32x4_t v_f = vec_splats(f); - float32x4_t v_zero = vec_splats(0.0f); - uint16x8_t v_hd = vec_round_from_fp32(v_f, v_zero, 0); - uint16x8_t v_h = vec_convert_to_fp16(v_hd, 0); - return vec_extract(v_h, 0); - } #endif // precomputed f32 table for f16 (256 KB) @@ -220,6 +195,47 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { #define GGML_F32_VEC_MUL GGML_F32xt_MUL #define GGML_F32_VEC_REDUCE GGML_F32xt_REDUCE +// F16 SVE +#define DEFAULT_PG32 svptrue_b32() +#define DEFAULT_PG16 svptrue_b16() + +#define GGML_F32Cxt svfloat16_t +#define GGML_F32Cxt_ZERO svdup_n_f16(0.0f) +#define GGML_F32Cxt_SET1(x) svdup_n_f16(x) +#define GGML_F32Cxt_LOAD(p) svld1_f16(DEFAULT_PG16, (const __fp16 *)(p)) +#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec)) + +#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a) +#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b) +#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b) +#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__) +#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED + +#define GGML_F16x_VEC GGML_F32Cxt +#define GGML_F16x_VEC_ZERO GGML_F32Cxt_ZERO +#define GGML_F16x_VEC_SET1 GGML_F32Cxt_SET1 +#define GGML_F16x_VEC_LOAD(p, i) GGML_F32Cxt_LOAD(p) +#define GGML_F16x_VEC_STORE(p, r, i) GGML_F32Cxt_STORE((__fp16 *)(p), r) +#define GGML_F16x_VEC_FMA GGML_F32Cxt_FMA +#define GGML_F16x_VEC_ADD GGML_F32Cxt_ADD +#define GGML_F16x_VEC_MUL GGML_F32Cxt_MUL +#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE + +#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a) +#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__) + +#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \ +{ \ + sum1 = svadd_f16_x(pg16, sum1, sum2); \ + sum3 = svadd_f16_x(pg16, sum3, sum4); \ + sum1 = svadd_f16_x(pg16, sum1, sum3); \ + __fp16 sum_f16 = svaddv_f16(pg16, sum1); \ + (res) = (ggml_float) sum_f16; \ +} +#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__) + // F16 NEON #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) @@ -1120,11 +1136,6 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { #define GGML_F16_EPR GGML_F32_EPR static inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) { -#if defined(__NNPA__) - uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)x); - uint16x8_t v_xd = vec_convert_from_fp16(v_x, 0); - return vec_extend_to_fp32_hi(v_xd, 0); -#else float tmp[4]; for (int i = 0; i < 4; i++) { @@ -1134,20 +1145,9 @@ static inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) { // note: keep type-cast here to prevent compiler bugs // see: https://github.com/ggml-org/llama.cpp/issues/12846 return vec_xl(0, (const float *)(tmp)); -#endif } static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { -#if defined(__NNPA__) - float32x4_t v_zero = vec_splats(0.0f); - uint16x8_t v_xd = vec_round_from_fp32(v_y, v_zero, 0); - uint16x8_t v_x = vec_convert_to_fp16(v_xd, 0); - - x[0] = vec_extract(v_x, 0); - x[1] = vec_extract(v_x, 1); - x[2] = vec_extract(v_x, 2); - x[3] = vec_extract(v_x, 3); -#else float arr[4]; // note: keep type-cast here to prevent compiler bugs @@ -1157,7 +1157,6 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { for (int i = 0; i < 4; i++) { x[i] = GGML_CPU_FP32_TO_FP16(arr[i]); } -#endif } #define GGML_F16_VEC GGML_F32x4 @@ -1170,6 +1169,36 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { #define GGML_F16_VEC_MUL GGML_F32x4_MUL #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE +#elif defined(__riscv_v_intrinsic) + +// compatible with vlen >= 128 + +#define GGML_SIMD + +// F32 + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 vfloat32m1_t +#define GGML_F32x4_ZERO __riscv_vfmv_v_f_f32m1(0.0f, GGML_F32_EPR) +#define GGML_F32x4_SET1(x) __riscv_vfmv_v_f_f32m1(x, GGML_F32_EPR) +#define GGML_F32x4_LOAD(x) __riscv_vle32_v_f32m1(x, GGML_F32_EPR) +#define GGML_F32x4_STORE(b, v) __riscv_vse32_v_f32m1(b, v, GGML_F32_EPR) +#define GGML_F32x4_FMA(a, b, c) __riscv_vfmacc_vv_f32m1(a, b, c, GGML_F32_EPR) +#define GGML_F32x4_ADD(a, b) __riscv_vfadd_vv_f32m1(a, b, GGML_F32_EPR) +#define GGML_F32x4_MUL(a, b) __riscv_vfmul_vv_f32m1(a, b, GGML_F32_EPR) + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + #endif // GGML_F32_ARR / GGML_F16_ARR diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 07b377bdd82..437192d525a 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -84,6 +84,22 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G } // reduce sum1,sum2 to sum1 GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8); + #elif defined(__riscv_v_intrinsic) + int vl = __riscv_vsetvlmax_e32m8(); + vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1); + vfloat32m8_t vsum; + vfloat32m8_t ax; + vfloat32m8_t ay; + vsum = __riscv_vfmv_v_f_f32m8_tu(vsum, 0.0f, vl); + for (int i = 0; i < n; i += vl) { + vl = __riscv_vsetvl_e32m8(n - i); + ax = __riscv_vle32_v_f32m8_tu(ax, &x[i], vl); + ay = __riscv_vle32_v_f32m8_tu(ay, &y[i], vl); + vsum = __riscv_vfmacc_vv_f32m8_tu(vsum, ax, ay, vl); + } + vl = __riscv_vsetvlmax_e32m8(); + vs = __riscv_vfredusum_vs_f32m8_f32m1(vsum, vs, vl); + sumf += __riscv_vfmv_f_s_f32m1_f32(vs); #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -197,38 +213,125 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G ggml_float sumf = 0.0; -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); - GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; +#if defined(GGML_SIMD) + #if defined(__ARM_FEATURE_SVE) + const int sve_register_length = svcntb() * 8; //get vector length + const int ggml_f16_epr = sve_register_length / 16; // running when 16 + const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers + + const int np= (n & ~(ggml_f16_step - 1)); + svfloat16_t sum1 = svdup_n_f16(0.0f); + svfloat16_t sum2 = svdup_n_f16(0.0f); + svfloat16_t sum3 = svdup_n_f16(0.0f); + svfloat16_t sum4 = svdup_n_f16(0.0f); + + svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; + svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + for (int i = 0; i < np; i += ggml_f16_step) { + ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); + sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1); + + ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); + sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2); + + ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); + ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); + sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3); + + ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); + ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); + sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4); + + ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); + ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); + sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5); + + ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); + ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); + sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6); + + ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); + ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); + sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7); + + ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); + ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); + sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8); + } - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; + const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8 + for (int k = np; k < np2; k += ggml_f16_epr) { + svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); + svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry); + } - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + if (np2 < n) { + svbool_t pg = svwhilelt_b16(np2, n); + svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); + svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); - sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + sum1 = svmad_f16_x(pg, hx, hy, sum1); } - } + GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4); + #elif defined(__riscv_v_intrinsic) + #if defined(__riscv_zvfh) + int vl = __riscv_vsetvlmax_e32m2(); + vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1); + vfloat32m2_t vsum; + vfloat16m1_t ax; + vfloat16m1_t ay; + vsum = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vmv_v_x_u32m2(0, vl)); + for (int i = 0; i < n; i += vl) { + vl = __riscv_vsetvl_e16m1(n - i); + ax = __riscv_vle16_v_f16m1_tu(ax, (const _Float16 *)&x[i], vl); + ay = __riscv_vle16_v_f16m1_tu(ay, (const _Float16 *)&y[i], vl); + vsum = __riscv_vfwmacc_vv_f32m2_tu(vsum, ax, ay, vl); + } + vl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t ac0 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(vsum, 0), __riscv_vget_v_f32m2_f32m1(vsum, 1), vl); + vs = __riscv_vfredusum_vs_f32m1_f32m1(ac0, vs, vl); + sumf += __riscv_vfmv_f_s_f32m1_f32(vs); + #else + for (int i = 0; i < n; ++i) { + sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); + } + #endif // __riscv_zvfh + #else + const int np = (n & ~(GGML_F16_STEP - 1)); - // reduce sum0..sum3 to sum0 - GGML_F16_VEC_REDUCE(sumf, sum); + GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; - // leftovers - for (int i = np; i < n; ++i) { - sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); - } + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - // if you hit this, you are likely running outside the FP range - assert(!isnan(sumf) && !isinf(sumf)); + sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F16_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); + } + // if you hit this, you are likely running outside the FP range + assert(!isnan(sumf) && !isinf(sumf)); + #endif #else for (int i = 0; i < n; ++i) { sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); } -#endif +#endif // GGML_SIMD *s = sumf; } @@ -247,6 +350,12 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) { for (; i + 3 < n; i += 4) { _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i))); } +#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + const int vlen = svcntw(); + for (; i < n; i += vlen) { + const svbool_t pg = svwhilelt_b32_s32(i, n); + svst1_f32(pg, y + i, ggml_v_silu(pg, svld1_f32(pg, x + i))); + } #elif defined(__ARM_NEON) && defined(__aarch64__) for (; i + 3 < n; i += 4) { vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i))); @@ -271,10 +380,24 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * for (; i + 3 < n; i += 4) { _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i))); } +#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + const int vlen = svcntw(); + for (; i < n; i += vlen) { + const svbool_t pg = svwhilelt_b32_s32(i, n); + svst1_f32(pg, y + i, svmul_f32_x(pg, ggml_v_silu(pg, svld1_f32(pg, x + i)), svld1_f32(pg, g + i))); + } #elif defined(__ARM_NEON) && defined(__aarch64__) for (; i + 3 < n; i += 4) { vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i))); } +#elif defined(__riscv_v_intrinsic) + for (int vl; i < n; i += vl) { + vl = __riscv_vsetvl_e32m2(n - i); + vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl); + vfloat32m2_t vg = __riscv_vle32_v_f32m2(&g[i], vl); + vfloat32m2_t vy = __riscv_vfmul_vv_f32m2(ggml_v_silu_m2(vx, vl), vg, vl); + __riscv_vse32_v_f32m2(&y[i], vy, vl); + } #endif for (; i < n; ++i) { y[i] = ggml_silu_f32(x[i]) * g[i]; @@ -318,6 +441,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float #endif sum += (ggml_float)_mm_cvtss_f32(val); } +#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + const int vlen = svcntw(); + for (; i < n; i += vlen) { + const svbool_t pg = svwhilelt_b32_s32(i, n); + svfloat32_t val = ggml_v_expf(pg, svsub_f32_x(pg, svld1_f32(pg, x + i), + svdup_n_f32_x(pg, max))); + svst1_f32(pg, y + i, val); + sum += (ggml_float)svaddv_f32(pg, val); + } #elif defined(__ARM_NEON) && defined(__aarch64__) for (; i + 3 < n; i += 4) { float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i), @@ -325,6 +457,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float vst1q_f32(y + i, val); sum += (ggml_float)vaddvq_f32(val); } +#elif defined(__riscv_v_intrinsic) + vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1); + for (int avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m2(n - i); + vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl); + __riscv_vse32_v_f32m2(&y[i], val, avl); + vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl); + } + return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum); #endif for (; i < n; ++i) { float val = expf(x[i] - max); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 2250d93cb00..ef334d089d1 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -119,36 +119,149 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG } #if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); + #if defined(__ARM_FEATURE_SVE) + + const int sve_register_length = svcntb() * 8; + const int ggml_f16_epr = sve_register_length / 16; // running when 16 + const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers + + const int np = (n & ~(ggml_f16_step - 1)); + + svfloat16_t sum_00 = svdup_n_f16(0.0f); + svfloat16_t sum_01 = svdup_n_f16(0.0f); + svfloat16_t sum_02 = svdup_n_f16(0.0f); + svfloat16_t sum_03 = svdup_n_f16(0.0f); + + svfloat16_t sum_10 = svdup_n_f16(0.0f); + svfloat16_t sum_11 = svdup_n_f16(0.0f); + svfloat16_t sum_12 = svdup_n_f16(0.0f); + svfloat16_t sum_13 = svdup_n_f16(0.0f); + + svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; + svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + + for (int i = 0; i < np; i += ggml_f16_step) { + ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements + + ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elemnst + sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1 + ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements + sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1); + + ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements + + ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 ekements + sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2); + ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1); + sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2); + + ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); + + ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2); + sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3); + ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2); + sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3); + + ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); + + ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3); + sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4); + ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3); + sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4); + + ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); + + ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4); + + sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5); + ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4); + sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5); + + ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); + + ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5); + + sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6); + ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5); + sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6); + + ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); + + ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6); + + sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7); + ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6); + sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7); + + ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); + + ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7); + + sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8); + ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7); + sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8); + } + + const int np2 = (n & ~(ggml_f16_epr - 1)); + for (int k = np; k < np2; k += ggml_f16_epr) { + svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + + svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0); + sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry); + rx = GGML_F16x_VEC_LOAD(x[1] + k, 0); + sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry); + } + + if (np2 < n) { + svbool_t pg = svwhilelt_b16(np2, n); + svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2)); + svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2)); + svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + + sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00); + sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10); + } + GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03); + GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13); + #elif defined(__riscv_v_intrinsic) + // todo: RVV impl + for (int i = 0; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); + } + } + #else + const int np = (n & ~(GGML_F16_STEP - 1)); - GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; + GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { - ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); - sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); + sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); + } } } - } - // reduce sum0..sum3 to sum0 - for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { - GGML_F16_VEC_REDUCE(sumf[k], sum[k]); - } + // reduce sum0..sum3 to sum0 + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + GGML_F16_VEC_REDUCE(sumf[k], sum[k]); + } - // leftovers - for (int i = np; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); + // leftovers + for (int i = np; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); + } } - } + #endif #else for (int i = 0; i < n; ++i) { for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { @@ -243,6 +356,14 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const svst1_f32(pg, y + np2, ay1); } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl); + vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); + vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, v, ay, avl); + __riscv_vse32_v_f32m8(&y[i], ny, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -276,27 +397,112 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) { #if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); + #if defined(__ARM_FEATURE_SVE) + const int sve_register_length = svcntb() * 8; + const int ggml_f16_epr = sve_register_length / 16; + const int ggml_f16_step = 8 * ggml_f16_epr; + + GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); + + const int np= (n & ~(ggml_f16_step - 1)); + + svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; + svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + for (int i = 0; i < np; i += ggml_f16_step) { + ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx); + + GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0); + + ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx); + + GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1); + + ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); + ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); + ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx); + + GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2); + + ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); + ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); + ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx); + + GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3); + + ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); + ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); + ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx); + + GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4); + + ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); + ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); + ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx); + + GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5); - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); + ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); + ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx); - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; + GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6); - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); + ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); + ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx); - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7); } - } + const int np2 = (n & ~(ggml_f16_epr - 1)); + for (int k = np; k < np2; k += ggml_f16_epr) { + svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); + svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + ry = GGML_F16x_VEC_FMA(ry, rx, vx); - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); - } + GGML_F16x_VEC_STORE(y + k, ry, 0); + } + + if (np2 < n) { + svbool_t pg = svwhilelt_b16(np2, n); + svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); + svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + hy = svmad_f16_x(pg, hx, vx, hy); + svst1_f16(pg, (__fp16 *)(y + np2), hy); + } + + #elif defined(__riscv_v_intrinsic) + // todo: RVV impl + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); + } + #else + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); + } + #endif #else // scalar for (int i = 0; i < n; ++i) { @@ -324,6 +530,16 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int y[i] += x[k][i]*v[k][0]; } } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); + for (int k = 0; k < GGML_VEC_MAD_UNROLL; k++) { + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[k][i], avl); + ay = __riscv_vfmadd_vf_f32m8(ax, v[k][0], ay, avl); + } + __riscv_vse32_v_f32m8(&y[i], ay, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -375,6 +591,14 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co for (int i = 0; i < n; ++i) { y[i] = x[i]*s + b; } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl); + vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl); + vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl); + __riscv_vse32_v_f32m8(&y[i], ny, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -436,6 +660,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { ay1 = svmul_f32_m(pg, ay1, vx); svst1_f32(pg, y + np, ay1); } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); + vfloat32m8_t ny = __riscv_vfmul_vf_f32m8(ay, v, avl); + __riscv_vse32_v_f32m8(&y[i], ny, avl); + } #else const int np = (n & ~(GGML_F32_STEP - 1)); @@ -467,25 +698,59 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { #if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); + #if defined(__ARM_FEATURE_SVE) + const int sve_register_length = svcntb() * 8; + const int ggml_f16_epr = sve_register_length / 16; + const int ggml_f16_step = 2 * ggml_f16_epr; + + GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); + const int np = (n & ~(ggml_f16_step - 1)); + svfloat16_t ay1, ay2; + + for (int i = 0; i < np; i += ggml_f16_step) { + ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0); + ay1 = GGML_F16x_VEC_MUL(ay1, vx); + GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0); + + ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1); + ay2 = GGML_F16x_VEC_MUL(ay2, vx); + GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1); + } + // leftovers + // maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only + if (np < n) { + svbool_t pg = svwhilelt_b16(np, n); + svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np)); + svfloat16_t out = svmul_f16_m(pg, hy, vx); + svst1_f16(pg, (__fp16 *)(y + np), out); + } + #elif defined(__riscv_v_intrinsic) + // todo: RVV impl + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); + } + #else + const int np = (n & ~(GGML_F16_STEP - 1)); - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); - GGML_F16_VEC ay[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } } - } - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); - } + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); + } + #endif #else // scalar for (int i = 0; i < n; ++i) { @@ -737,7 +1002,39 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) { } #endif -#if defined(__ARM_NEON) && defined(__aarch64__) +#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + +inline static svfloat32_t ggml_v_expf(svbool_t pg, svfloat32_t x) { + const svfloat32_t r = svdup_n_f32_x(pg, 0x1.8p23f); + const svfloat32_t z = svmla_n_f32_x(pg, r, x, 0x1.715476p+0f); + const svfloat32_t n = svsub_f32_x(pg, z, r); + const svfloat32_t b = svmls_n_f32_x(pg, svmls_n_f32_x(pg, x, n, 0x1.62e4p-1f), n, 0x1.7f7d1cp-20f); + const svuint32_t e = svlsl_n_u32_x(pg, svreinterpret_u32_f32(z), 23); + const svfloat32_t k = svreinterpret_f32_u32(svadd_u32_x(pg, e, svreinterpret_u32_f32(svdup_n_f32_x(pg, 1)))); + const svbool_t c = svacgt_n_f32(pg, n, 126); + const svfloat32_t u = svmul_f32_x(pg, b, b); + const svfloat32_t j = svmla_f32_x(pg, + svmul_n_f32_x(pg, b, 0x1.ffffecp-1f), + svmla_f32_x(pg, svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.fffdb6p-2f), svdup_n_f32_x(pg, 0x1.555e66p-3f), b), + svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.573e2ep-5f), svdup_n_f32_x(pg, 0x1.0e4020p-7f), b), u), u); + const svuint32_t d = svdup_n_u32_z(svcmple_n_f32(pg, n, 0.0), 0x82000000); + const svfloat32_t s1 = svreinterpret_f32_u32(svadd_n_u32_x(pg, d, 0x7f000000)); + const svfloat32_t s2 = svreinterpret_f32_u32(svsub_u32_x(pg, e, d)); + return svsel_f32(svacgt_f32(pg, n, svdup_n_f32_x(pg, 192)), svmul_f32_x(pg, s1, s1), + svsel_f32(c, svmul_f32_x(pg, svmla_f32_x(pg, s2, s2, j), s1), svmla_f32_x(pg, k, k, j))); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) { + const svfloat32_t one = svdup_n_f32_x(pg, 1.0f); + const svfloat32_t zero = svdup_n_f32_x(pg, 0.0f); + const svfloat32_t neg_x = svsub_f32_x(pg, zero, x); + const svfloat32_t exp_neg_x = ggml_v_expf(pg, neg_x); + const svfloat32_t one_plus_exp_neg_x = svadd_f32_x(pg, one, exp_neg_x); + return svdiv_f32_x(pg, x, one_plus_exp_neg_x); +} + +#elif defined(__ARM_NEON) && defined(__aarch64__) // adapted from arm limited optimized routine // the maximum error is 1.45358 plus 0.5 ulps @@ -928,7 +1225,59 @@ inline static __m128 ggml_v_silu(__m128 x) { return _mm_div_ps(x, one_plus_exp_neg_x); } -#endif // __ARM_NEON / __AVX2__ / __SSE2__ +#elif defined(__riscv_v_intrinsic) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) { + const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl); +#ifdef __riscv_xtheadvector + // workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4') + vfloat32m2_t z = __riscv_vfadd_vf_f32m2(r, 0.0f, vl); + z = __riscv_vfmacc_vf_f32m2(z, 0x1.715476p+0f, x, vl); +#else + const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl); +#endif + const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl); + const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl), + 0x1.7f7d1cp-20f, n, vl); + const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl); + const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); // 1.0f + const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl); + const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl); + const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2( + __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl), + __riscv_vfmacc_vv_f32m2( + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl), + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl), + u, vl), u, vl); + if (!__riscv_vcpop_m_b16(c, vl)) + return __riscv_vfmacc_vv_f32m2(k, j, k, vl); + const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl); + const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl); + const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl)); + const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl)); + const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2( + __riscv_vfmacc_vv_f32m2(k, k, j, vl), + __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl), + c, vl); + return __riscv_vmerge_vvm_f32m2( + r1, __riscv_vfmul_vv_f32m2(s1, s1, vl), + __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl), + vl); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) { + const vfloat32m2_t neg_x = __riscv_vfneg_v_f32m2(x, vl); + const vfloat32m2_t exp_neg_x = ggml_v_expf_m2(neg_x, vl); + const vfloat32m2_t one_plus_exp_neg_x = __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl); + return __riscv_vfdiv_vv_f32m2(x, one_plus_exp_neg_x, vl); +} + +#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) { diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index bce07ac3628..bdcefe7b7ed 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -24,17 +24,15 @@ if (CUDAToolkit_FOUND) # for best performance and to also build real architectures for the most commonly used GPUs. if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") set(CMAKE_CUDA_ARCHITECTURES "native") - elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) - if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8") - set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real") - else() - set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real") - endif() else() + if (CUDAToolkit_VERSION VERSION_LESS "13") + list(APPEND CMAKE_CUDA_ARCHITECTURES 50-virtual 61-virtual 70-virtual) + endif () + + list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8") - set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real") - else() - set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real") + list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real) endif() endif() endif() @@ -50,6 +48,8 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmq*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/mmf*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) if (GGML_CUDA_FA_ALL_QUANTS) file(GLOB SRCS "template-instances/fattn-vec*.cu") @@ -91,10 +91,6 @@ if (CUDAToolkit_FOUND) add_compile_definitions(GGML_CUDA_NO_FA) endif() - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) - add_compile_definitions(GGML_CUDA_F16) - endif() - if (GGML_CUDA_NO_PEER_COPY) add_compile_definitions(GGML_CUDA_NO_PEER_COPY) endif() @@ -104,7 +100,11 @@ if (CUDAToolkit_FOUND) # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas) else () - target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1") + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + else() + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static) + endif() endif() else() target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas) diff --git a/ggml/src/ggml-cuda/add-id.cu b/ggml/src/ggml-cuda/add-id.cu index 8bed62ac9d2..8d9cf692b4b 100644 --- a/ggml/src/ggml-cuda/add-id.cu +++ b/ggml/src/ggml-cuda/add-id.cu @@ -11,14 +11,14 @@ static __global__ void add_id_kernel( const int64_t i1 = blockIdx.x; const int64_t i2 = blockIdx.y; - const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21); + const int i11 = *(const int32_t *) ((const char *) src2 + i1*sizeof(int32_t) + i2*nb21); const size_t nb1 = ne0 * sizeof(float); const size_t nb2 = ne1 * nb1; float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2); - const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02); - const float * src1_row = (const float *)((char *)src1 + i11*nb11); + const float * src0_row = (const float *)((const char *)src0 + i1*nb01 + i2*nb02); + const float * src1_row = (const float *)((const char *)src1 + i11*nb11); for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { dst_row[i0] = src0_row[i0] + src1_row[i0]; diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index e1fbf0e1366..725e1a81a1f 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -1,5 +1,6 @@ #include "binbcast.cuh" #include +#include static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; @@ -22,73 +23,295 @@ static __device__ __forceinline__ float op_div(const float a, const float b) { return a / b; } -template -static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13) { - const int i0s = blockDim.x*blockIdx.x + threadIdx.x; - const int i1 = (blockDim.y*blockIdx.y + threadIdx.y); - const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3; - const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3; - - if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { +template +static __global__ void k_bin_bcast(const src0_t * src0, + const src1_t * src1, + dst_t * dst, + const int ne0, + const int ne1, + const int ne2, + const uint3 ne3, + const uint3 ne10, + const uint3 ne11, + const uint3 ne12, + const uint3 ne13, + /*int s0, */ const int s1, + const int s2, + const int s3, + /*int s00,*/ const int s01, + const int s02, + const int s03, + /*int s10,*/ const int s11, + const int s12, + const int s13, + src1_ptrs... src1s) { + const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x; + const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y); + const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3); + const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z); + + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) { return; } - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; + const uint32_t i11 = fastmodulo(i1, ne11); + const uint32_t i12 = fastmodulo(i2, ne12); + const uint32_t i13 = fastmodulo(i3, ne13); const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; const size_t i_dst = i3*s3 + i2*s2 + i1*s1; - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; + const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; - for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) { - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { + const uint32_t i10 = fastmodulo(i0, ne10); + + float result = src0_row ? (float) src0_row[i0] : 0.0f; + if constexpr (sizeof...(src1_ptrs) > 0) { + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + } else { + result = bin_op(result, (float)src1[i_src1 + i10]); + } + + dst_row[i0] = (dst_t) result; } } -template -static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13) { - +template +static __global__ void k_bin_bcast_unravel(const src0_t * src0, + const src1_t * src1, + dst_t * dst, + const uint3 ne0, + const uint3 ne1, + const uint3 ne2, + const uint32_t ne3, + const uint3 prod_012, + const uint3 prod_01, + const uint3 ne10, + const uint3 ne11, + const uint3 ne12, + const uint3 ne13, + /*int s0, */ const int s1, + const int s2, + const int s3, + /*int s00,*/ const int s01, + const int s02, + const int s03, + /*int s10,*/ const int s11, + const int s12, + const int s13, + src1_ptrs... src1s) { const int i = blockDim.x*blockIdx.x + threadIdx.x; - const int i3 = i/(ne2*ne1*ne0); - const int i2 = (i/(ne1*ne0)) % ne2; - const int i1 = (i/ne0) % ne1; - const int i0 = i % ne0; + const uint32_t i3 = fastdiv(i, prod_012); + const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01); + const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0); + const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z; - if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) { return; } - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; + const int i11 = fastmodulo(i1, ne11); + const int i12 = fastmodulo(i2, ne12); + const int i13 = fastmodulo(i3, ne13); const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; const size_t i_dst = i3*s3 + i2*s2 + i1*s1; - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; + const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + const int i10 = fastmodulo(i0, ne10); + + float result = src0_row ? (float) src0_row[i0] : 0.0f; + if constexpr (sizeof...(src1_ptrs) > 0) { + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + } else { + result = bin_op(result, (float)src1[i_src1 + i10]); + } + + dst_row[i0] = (dst_t) result; +} + +template +static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, + cudaStream_t stream, std::index_sequence) { + GGML_TENSOR_BINARY_OP_LOCALS + + int nr0 = ne10 / ne0; + int nr1 = ne11 / ne1; + int nr2 = ne12 / ne2; + int nr3 = ne13 / ne3; + + int nr[4] = { nr0, nr1, nr2, nr3 }; + + int64_t cne[] = { ne0, ne1, ne2, ne3 }; + int64_t cne0[] = { ne00, ne01, ne02, ne03 }; + int64_t cne1[] = { ne10, ne11, ne12, ne13 }; + + size_t cnb[] = { nb0, nb1, nb2, nb3 }; + size_t cnb0[] = { nb00, nb01, nb02, nb03 }; + size_t cnb1[] = { nb10, nb11, nb12, nb13 }; + + auto collapse = [](int64_t cne[]) { + cne[0] *= cne[1]; + cne[1] = cne[2]; + cne[2] = cne[3]; + cne[3] = 1; + }; + + auto collapse_nb = [](size_t cnb[], const int64_t cne[]) { + cnb[1] *= cne[1]; + cnb[2] *= cne[2]; + cnb[3] *= cne[3]; + }; + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb, cne); + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne); + collapse(cne0); + collapse(cne1); + } + } + } + + { + int64_t ne0 = cne[0]; + int64_t ne1 = cne[1]; + int64_t ne2 = cne[2]; + int64_t ne3 = cne[3]; + + //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00); + //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01); + //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02); + //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03); + + size_t nb0 = cnb[0]; + size_t nb1 = cnb[1]; + size_t nb2 = cnb[2]; + size_t nb3 = cnb[3]; + + size_t nb00 = cnb0[0]; + size_t nb01 = cnb0[1]; + size_t nb02 = cnb0[2]; + size_t nb03 = cnb0[3]; + + size_t nb10 = cnb1[0]; + size_t nb11 = cnb1[1]; + size_t nb12 = cnb1[2]; + size_t nb13 = cnb1[3]; + + size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); + + size_t s10 = nb10 / sizeof(src1_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + size_t s00 = nb00 / sizeof(src0_t); + size_t s01 = nb01 / sizeof(src0_t); + size_t s02 = nb02 / sizeof(src0_t); + size_t s03 = nb03 / sizeof(src0_t); + + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); + GGML_ASSERT(nb1 % sizeof(dst_t) == 0); + GGML_ASSERT(nb2 % sizeof(dst_t) == 0); + GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + + GGML_ASSERT(nb00 % sizeof(src0_t) == 0); + GGML_ASSERT(nb01 % sizeof(src0_t) == 0); + GGML_ASSERT(nb02 % sizeof(src0_t) == 0); + GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + + GGML_ASSERT(nb10 % sizeof(src1_t) == 0); + GGML_ASSERT(nb11 % sizeof(src1_t) == 0); + GGML_ASSERT(nb12 % sizeof(src1_t) == 0); + GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + + GGML_ASSERT(s0 == 1); + GGML_ASSERT(s00 == 1); + GGML_ASSERT(s10 == 1); + + const int block_size = 128; + + int64_t hne0 = std::max(ne0 / 2LL, 1LL); + + dim3 block_dims; + block_dims.x = std::min(hne0, block_size); + block_dims.y = std::min(ne1, block_size / block_dims.x); + block_dims.z = std::min(std::min(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U); + + dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y, + (ne2 * ne3 + block_dims.z - 1) / block_dims.z); + + const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]); + const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]); + const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]); + const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]); + + if (block_nums.z > 65535) { + int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; + const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2)); + const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1)); + const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0); + const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1); + const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2); + + if constexpr (sizeof...(I) > 0) { + k_bin_bcast_unravel<<>>( + src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, + ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + } else { + k_bin_bcast_unravel + <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, + ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13); + } + } else { + const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); + if constexpr (sizeof...(I) > 0) { + k_bin_bcast<<>>( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + } else { + k_bin_bcast<<>>( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13); + } + } + } } template @@ -120,160 +343,14 @@ static __global__ void k_repeat_back( dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum; } -template +template struct bin_bcast_cuda { template void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, cudaStream_t stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - - int nr0 = ne10/ne0; - int nr1 = ne11/ne1; - int nr2 = ne12/ne2; - int nr3 = ne13/ne3; - - int nr[4] = { nr0, nr1, nr2, nr3 }; - - // collapse dimensions until first broadcast dimension - int64_t cne[] = {ne0, ne1, ne2, ne3}; - int64_t cne0[] = {ne00, ne01, ne02, ne03}; - int64_t cne1[] = {ne10, ne11, ne12, ne13}; - - size_t cnb[] = {nb0, nb1, nb2, nb3}; - size_t cnb0[] = {nb00, nb01, nb02, nb03}; - size_t cnb1[] = {nb10, nb11, nb12, nb13}; - - auto collapse = [](int64_t cne[]) { - cne[0] *= cne[1]; - cne[1] = cne[2]; - cne[2] = cne[3]; - cne[3] = 1; - }; - - auto collapse_nb = [](size_t cnb[], const int64_t cne[]) { - cnb[1] *= cne[1]; - cnb[2] *= cne[2]; - cnb[3] *= cne[3]; - }; - - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { - for (int i = 0; i < 4; i++) { - if (nr[i] != 1) { - break; - } - if (i > 0) { - collapse_nb(cnb, cne); - collapse_nb(cnb0, cne0); - collapse_nb(cnb1, cne1); - collapse(cne); - collapse(cne0); - collapse(cne1); - } - } - } - - { - int64_t ne0 = cne[0]; - int64_t ne1 = cne[1]; - int64_t ne2 = cne[2]; - int64_t ne3 = cne[3]; - - //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00); - //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01); - //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02); - //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03); - - int64_t ne10 = cne1[0]; - int64_t ne11 = cne1[1]; - int64_t ne12 = cne1[2]; - int64_t ne13 = cne1[3]; - - size_t nb0 = cnb[0]; - size_t nb1 = cnb[1]; - size_t nb2 = cnb[2]; - size_t nb3 = cnb[3]; - - size_t nb00 = cnb0[0]; - size_t nb01 = cnb0[1]; - size_t nb02 = cnb0[2]; - size_t nb03 = cnb0[3]; - - size_t nb10 = cnb1[0]; - size_t nb11 = cnb1[1]; - size_t nb12 = cnb1[2]; - size_t nb13 = cnb1[3]; - - size_t s0 = nb0 / sizeof(dst_t); - size_t s1 = nb1 / sizeof(dst_t); - size_t s2 = nb2 / sizeof(dst_t); - size_t s3 = nb3 / sizeof(dst_t); - - size_t s10 = nb10 / sizeof(src1_t); - size_t s11 = nb11 / sizeof(src1_t); - size_t s12 = nb12 / sizeof(src1_t); - size_t s13 = nb13 / sizeof(src1_t); - - size_t s00 = nb00 / sizeof(src0_t); - size_t s01 = nb01 / sizeof(src0_t); - size_t s02 = nb02 / sizeof(src0_t); - size_t s03 = nb03 / sizeof(src0_t); - - GGML_ASSERT(nb0 % sizeof(dst_t) == 0); - GGML_ASSERT(nb1 % sizeof(dst_t) == 0); - GGML_ASSERT(nb2 % sizeof(dst_t) == 0); - GGML_ASSERT(nb3 % sizeof(dst_t) == 0); - - GGML_ASSERT(nb00 % sizeof(src0_t) == 0); - GGML_ASSERT(nb01 % sizeof(src0_t) == 0); - GGML_ASSERT(nb02 % sizeof(src0_t) == 0); - GGML_ASSERT(nb03 % sizeof(src0_t) == 0); - - GGML_ASSERT(nb10 % sizeof(src1_t) == 0); - GGML_ASSERT(nb11 % sizeof(src1_t) == 0); - GGML_ASSERT(nb12 % sizeof(src1_t) == 0); - GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s00 == 1); - GGML_ASSERT(s10 == 1); - - const int block_size = 128; - - int64_t hne0 = std::max(ne0/2LL, 1LL); - - dim3 block_dims; - block_dims.x = std::min(hne0, block_size); - block_dims.y = std::min(ne1, block_size / block_dims.x); - block_dims.z = std::min(std::min(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U); - - dim3 block_nums( - (hne0 + block_dims.x - 1) / block_dims.x, - (ne1 + block_dims.y - 1) / block_dims.y, - (ne2*ne3 + block_dims.z - 1) / block_dims.z - ); - - if (block_nums.z > 65535) { - // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel - int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; - k_bin_bcast_unravel<<>>( - src0_dd, src1_dd, dst_dd, - ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00, */ s01, s02, s03, - /* s10, */ s11, s12, s13); - } else { - k_bin_bcast<<>>( - src0_dd, src1_dd, dst_dd, - ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00, */ s01, s02, s03, - /* s10, */ s11, s12, s13); - } - } + launch_bin_bcast_pack( + src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence{}); } }; @@ -312,7 +389,7 @@ static void ggml_cuda_op_bin_bcast( } void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream()); + ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream()); } void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -331,6 +408,68 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); } +template +static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + cudaStream_t stream = ctx.stream(); + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + launch_bin_bcast_pack(src0, src1, dst, + (const float *) src0->data, (const float *) src1->data, (float *) dst->data, + stream, std::make_index_sequence{}); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + launch_bin_bcast_pack(src0, src1, dst, + (const half *) src0->data, (const half *) src1->data, (half *) dst->data, + stream, std::make_index_sequence{}); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + launch_bin_bcast_pack(src0, src1, dst, + (const half *) src0->data, (const float *) src1->data, (half *) dst->data, + stream, std::make_index_sequence{}); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + launch_bin_bcast_pack(src0, src1, dst, + (const half *) src0->data, (const float *) src1->data, (float *) dst->data, + stream, std::make_index_sequence{}); + } else { + fprintf(stderr, + "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n", + __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } +} + + +void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) { + GGML_ASSERT(2 <= n_fuse && n_fuse <= 8); + + switch (n_fuse) { + case 2: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 3: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 4: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 5: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 6: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 7: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 8: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + default: + GGML_ASSERT(false && "Unsupported n_fuse value"); + } +} + void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh index 3ac1c9b03fc..62bc950111b 100644 --- a/ggml/src/ggml-cuda/binbcast.cuh +++ b/ggml/src/ggml-cuda/binbcast.cuh @@ -7,3 +7,5 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 2b14b30ac90..3b1349171b2 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -75,9 +75,13 @@ #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) #define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1) #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2) +#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3) #define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) // Moore Threads +#define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons + #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 #define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000 #define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD @@ -105,9 +109,9 @@ constexpr bool ggml_cuda_has_arch(const int arch) { return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__); } -constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) { +constexpr int ggml_cuda_highest_compiled_arch_impl(const int /*arch*/, const int cur) { if (cur == 0) { - GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch); + return -1; } return cur; } @@ -204,14 +208,6 @@ static const char * cu_get_error_str(CUresult err) { #define GGML_CUDA_ASSUME(x) #endif // CUDART_VERSION >= 11010 -#ifdef GGML_CUDA_F16 -typedef half dfloat; // dequantize float -typedef half2 dfloat2; -#else -typedef float dfloat; // dequantize float -typedef float2 dfloat2; -#endif // GGML_CUDA_F16 - #if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) #define GGML_USE_VMM #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) @@ -331,6 +327,20 @@ static constexpr __device__ int ggml_cuda_get_physical_warp_size() { #endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) } +// Maximum number of bytes that can be copied in a single instruction. +static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() { +#ifdef GGML_USE_HIP + return 16; +#else +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + return 16; +#else + return 8; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // GGML_USE_HIP +} + + [[noreturn]] static __device__ void no_device_code( const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { @@ -426,16 +436,28 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { template static __device__ __forceinline__ int warp_reduce_all(int x) { -#ifdef GGML_USE_HIP + if (width == ggml_cuda_get_physical_warp_size()) { + return __all_sync(0xffffffff, x); + } else { #pragma unroll - for (int offset = width/2; offset > 0; offset >>= 1) { - x = x && __shfl_xor_sync(0xffffffff, x, offset, width); + for (int offset = width/2; offset > 0; offset >>= 1) { + x = __shfl_xor_sync(0xffffffff, x, offset, width) && x; + } + return x; + } +} + +template +static __device__ __forceinline__ int warp_reduce_any(int x) { + if (width == ggml_cuda_get_physical_warp_size()) { + return __any_sync(0xffffffff, x); + } else { +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x = __shfl_xor_sync(0xffffffff, x, offset, width) || x; + } + return x; } - return x; -#else - static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented"); - return __all_sync(0xffffffff, x); -#endif // GGML_USE_HIP } template @@ -490,13 +512,14 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP) } -#if CUDART_VERSION < CUDART_HMASK +#if (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || \ + (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK) static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) { const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b))); const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b))); return mask_low | mask_high; } -#endif // CUDART_VERSION < CUDART_HMASK +#endif // (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK) static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) { #if defined(GGML_USE_HIP) @@ -538,6 +561,45 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i #endif // defined(GGML_USE_HIP) } +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) { + acc += v*u; +} + +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) { + acc += v.x*u.x; + acc += v.y*u.y; +} + +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) { +#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA)) + asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u)); +#else +#ifdef FAST_FP16_AVAILABLE + const float2 tmp = __half22float2(v*u); + acc += tmp.x + tmp.y; +#else + const float2 tmpv = __half22float2(v); + const float2 tmpu = __half22float2(u); + acc += tmpv.x * tmpu.x; + acc += tmpv.y * tmpu.y; +#endif // FAST_FP16_AVAILABLE +#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA)) +} + +// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD. +template +static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) { + if constexpr (nbytes == 4) { + *(int *) dst = *(const int *) src; + } else if constexpr (nbytes == 8) { + *(int2 *) dst = *(const int2 *) src; + } else if constexpr (nbytes == 16) { + *(int4 *) dst = *(const int4 *) src; + } else { + static_assert(nbytes == 0 && nbytes == -1, "bad nbytes"); + } +} + static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { #if CUDART_VERSION >= 12080 const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x); @@ -556,7 +618,49 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { #endif // CUDART_VERSION >= 12050 } -typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +static const uint3 init_fastdiv_values(uint32_t d) { + GGML_ASSERT(d != 0); + + // compute L = ceil(log2(d)); + uint32_t L = 0; + while (L < 32 && (uint32_t{ 1 } << L) < d) { + L++; + } + + uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); + // pack divisor as well to reduce error surface + return make_uint3(mp, L, d); +} + +static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) { + // expects fastdiv_values to contain in + // fastdiv_values.z is unused and optimized away by the compiler. + // Compute high 32 bits of n * mp + const uint32_t hi = __umulhi(n, fastdiv_values.x); + // add n, apply bit shift + return (hi + n) >> fastdiv_values.y; +} + +static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) { + // expects fastdiv_values to contain in (see init_fastdiv_values) + return n - fastdiv(n, fastdiv_values) * fastdiv_values.z; +} + +// Calculate both division and modulo at once, returns +static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) { + // expects fastdiv_values to contain in (see init_fastdiv_values) + const uint32_t div_val = fastdiv(n, fastdiv_values); + const uint32_t mod_val = n - div_val * fastdiv_values.z; + return make_uint2(div_val, mod_val); +} + +typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v); static __device__ __forceinline__ float get_alibi_slope( const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1 diff --git a/ggml/src/ggml-cuda/conv-transpose-1d.cu b/ggml/src/ggml-cuda/conv-transpose-1d.cu index fe4caf674d4..8418ba66731 100644 --- a/ggml/src/ggml-cuda/conv-transpose-1d.cu +++ b/ggml/src/ggml-cuda/conv-transpose-1d.cu @@ -34,10 +34,7 @@ static __global__ void conv_transpose_1d_kernel( } } dst[global_index] = accumulator; - GGML_UNUSED(p0); GGML_UNUSED(d0); GGML_UNUSED(src0_ne3); - GGML_UNUSED(src1_ne3); GGML_UNUSED(dst_ne3); - GGML_UNUSED(src1_ne1); GGML_UNUSED(dst_ne1); - GGML_UNUSED(src1_ne2); GGML_UNUSED(dst_ne2); + GGML_UNUSED_VARS(p0, d0, src0_ne3, src1_ne3, dst_ne3, src1_ne1, dst_ne1, src1_ne2, dst_ne2); } static void conv_transpose_1d_f32_f32_cuda( diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu new file mode 100644 index 00000000000..142dd66903a --- /dev/null +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -0,0 +1,166 @@ +#include "conv2d.cuh" +#include "convert.cuh" + +struct conv_params { + const int64_t IW, IH; + const int64_t OW, OH; + const int64_t KW, KH; + const int64_t ST_X, ST_Y; + const int64_t PD_X, PD_Y; + const int64_t DL_X, DL_Y; + const int64_t IC, OC; + const int64_t B; + const int64_t TOTAL; +}; + +struct kernel_bounds { + int64_t y_min, y_max; + int64_t x_min, x_max; +}; + +__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) { + return (a > b) ? a : b; +} + +__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) { + return (a < b) ? a : b; +} + +__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) { + kernel_bounds bounds; + bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); + bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); + bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); + bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); + return bounds; +} + +__device__ __forceinline__ int calculate_input_coord(int64_t out_coord, + int64_t kern_coord, + int64_t stride, + int64_t dilation, + int64_t padding) { + return out_coord * stride + kern_coord * dilation - padding; +} + +struct whcn_layout { + __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; + } + + __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) { + return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; + } + + __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; + } + + __device__ static void unpack_indices(int64_t global_idx, + const conv_params & P, + int64_t & n, + int64_t & c, + int64_t & out_y, + int64_t & out_x) { + out_x = global_idx % P.OW; + out_y = (global_idx / P.OW) % P.OH; + c = (global_idx / (P.OW * P.OH)) % P.OC; + n = global_idx / (P.OW * P.OH * P.OC); + } +}; + +template +static __global__ void conv2d_kernel(const float * __restrict__ input, + const T * __restrict__ kernel, + float * __restrict__ output, + const conv_params P) { + const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (global_idx >= P.TOTAL) { + return; + } + + int64_t n, c_out, out_y, out_x; + Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x); + + float acc = 0.0f; + + for (int64_t c_in = 0; c_in < P.IC; ++c_in) { + kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P); + + for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) { + const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y); + + for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) { + const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X); + + const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)]; + const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)]; + acc += (input_val * ggml_cuda_cast(kernel_val)); + } + } + } + + // [N, OC, OH, OW] + output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc; +} + +template +static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; + conv2d_kernel<<>>(X_D, K_D, Y_D, P); +} + +static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + conv2d_cuda(X_D, K_D, Y_D, P, st); +} + +static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + conv2d_cuda(X_D, K_D, Y_D, P, st); +} + +void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * kernel = dst->src[0]; + const ggml_tensor * input = dst->src[1]; + float * K_D = (float *) kernel->data; + const float * X_D = (const float *) input->data; + float * Y_D = (float *) dst->data; + + GGML_ASSERT(ggml_is_contiguous(kernel)); + GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32); + + // same number of input channels + GGML_ASSERT(input->ne[2] == kernel->ne[2]); + + cudaStream_t st = ctx.stream(); + + const int32_t * p = (const int32_t *) dst->op_params; + const int ST_X = p[0]; // stride_x + const int ST_Y = p[1]; // stride_y + const int PD_X = p[2]; // padding_x + const int PD_Y = p[3]; // padding_y + const int DL_X = p[4]; // dilation_x + const int DL_Y = p[5]; // dilation_y + + // No cwhn + GGML_ASSERT(p[6] == false); + + const int IW = input->ne[0]; // input_w + const int IH = input->ne[1]; // input_h + const int OW = dst->ne[0]; // output_w + const int OH = dst->ne[1]; // output_h + const int KW = kernel->ne[0]; // kernel_w + const int KH = kernel->ne[1]; // kernel_h + const int IC = input->ne[2]; // input_channels + const int OC = kernel->ne[3]; // ouptut_chanles + const int B = input->ne[3]; // n_batches + + const int64_t total = B * OC * OH * OW; + conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; + + if (kernel->type == GGML_TYPE_F16) { + conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st); + } else { + conv2d_cuda_f32(X_D, K_D, Y_D, params, st); + } +} diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh new file mode 100644 index 00000000000..ce4802c7ed7 --- /dev/null +++ b/ggml/src/ggml-cuda/conv2d.cuh @@ -0,0 +1,5 @@ +#pragma once +#include "common.cuh" + +#define CUDA_CONV2D_BLOCK_SIZE 256 +void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 8f0efdcc126..ba3d4eeb880 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -27,7 +27,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ const int64_t y_offset = qr == 1 ? 1 : qk/2; // dequantize - dfloat2 v; + float2 v; dequantize_kernel(vx, ib, iqs, v); const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs; @@ -71,9 +71,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d)); } #else - GGML_UNUSED(vx); - GGML_UNUSED(y); - GGML_UNUSED(k); + GGML_UNUSED_VARS(vx, y, k); NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL } diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index c62e8a1b104..ef9e129950c 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -38,6 +38,8 @@ template return __float2bfloat16(float(x)); } else if constexpr(std::is_same_v) { return __bfloat162float(x); + } else if constexpr(std::is_same_v) { + return int32_t(x); } else { return float(x); } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index f9bb025643c..1b763a62898 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -42,7 +42,7 @@ static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { #pragma unroll for (int j = 0; j < QK8_0; j += 2) { - dfloat2 dq; + float2 dq; dequantize_q8_0(cxi, 0, j, dq); *(cdstf + j) = dq.x; *(cdstf + j + 1) = dq.y; @@ -55,7 +55,7 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { #pragma unroll for (int j = 0; j < qk/2; j++) { - dfloat2 dq; + float2 dq; dequant(cxi, 0, j, dq); *(cdstf + j) = dq.x; *(cdstf + j + qk/2) = dq.y; @@ -134,8 +134,7 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream)); cuda_graph->graph_cpynode_index = 0; // reset index #else - GGML_UNUSED(cuda_graph); GGML_UNUSED(host_dest_ptrs); - GGML_UNUSED(host_dest_ptrs_size); GGML_UNUSED(stream); + GGML_UNUSED_VARS(cuda_graph, host_dest_ptrs, host_dest_ptrs_size, stream); #endif } @@ -375,6 +374,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); @@ -438,6 +441,10 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { return (void*) cpy_flt>; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { + return (void*) cpy_flt>; + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_flt>; } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh index bd3c2d9db94..e060fb29fdc 100644 --- a/ggml/src/ggml-cuda/dequantize.cuh +++ b/ggml/src/ggml-cuda/dequantize.cuh @@ -1,48 +1,37 @@ #include "common.cuh" -static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; - const dfloat d = x[ib].d; + const float d = x[ib].d; const int vui = x[ib].qs[iqs]; v.x = vui & 0xF; v.y = vui >> 4; -#ifdef GGML_CUDA_F16 - v = __hsub2(v, {8.0f, 8.0f}); - v = __hmul2(v, {d, d}); -#else v.x = (v.x - 8.0f) * d; v.y = (v.y - 8.0f) * d; -#endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q4_1 * x = (const block_q4_1 *) vx; - const dfloat d = __low2half(x[ib].dm); - const dfloat m = __high2half(x[ib].dm); + const float2 dm = __half22float2(x[ib].dm); const int vui = x[ib].qs[iqs]; v.x = vui & 0xF; v.y = vui >> 4; -#ifdef GGML_CUDA_F16 - v = __hmul2(v, {d, d}); - v = __hadd2(v, {m, m}); -#else - v.x = (v.x * d) + m; - v.y = (v.y * d) + m; -#endif // GGML_CUDA_F16 + v.x = (v.x * dm.x) + dm.y; + v.y = (v.y * dm.x) + dm.y; } -static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q5_0 * x = (const block_q5_0 *) vx; - const dfloat d = x[ib].d; + const float d = x[ib].d; uint32_t qh; memcpy(&qh, x[ib].qh, sizeof(qh)); @@ -53,20 +42,14 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); v.y = ((x[ib].qs[iqs] >> 4) | xh_1); -#ifdef GGML_CUDA_F16 - v = __hsub2(v, {16.0f, 16.0f}); - v = __hmul2(v, {d, d}); -#else v.x = (v.x - 16.0f) * d; v.y = (v.y - 16.0f) * d; -#endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q5_1 * x = (const block_q5_1 *) vx; - const dfloat d = __low2half(x[ib].dm); - const dfloat m = __high2half(x[ib].dm); + const float2 dm = __half22float2(x[ib].dm); uint32_t qh; memcpy(&qh, x[ib].qh, sizeof(qh)); @@ -77,27 +60,18 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); v.y = ((x[ib].qs[iqs] >> 4) | xh_1); -#ifdef GGML_CUDA_F16 - v = __hmul2(v, {d, d}); - v = __hadd2(v, {m, m}); -#else - v.x = (v.x * d) + m; - v.y = (v.y * d) + m; -#endif // GGML_CUDA_F16 + v.x = (v.x * dm.x) + dm.y; + v.y = (v.y * dm.x) + dm.y; } -static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q8_0 * x = (const block_q8_0 *) vx; - const dfloat d = x[ib].d; + const float d = x[ib].d; v.x = x[ib].qs[iqs + 0]; v.y = x[ib].qs[iqs + 1]; -#ifdef GGML_CUDA_F16 - v = __hmul2(v, {d, d}); -#else v.x *= d; v.y *= d; -#endif // GGML_CUDA_F16 } diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index d4ed938391b..142a3a88d1d 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -647,9 +647,7 @@ static __global__ void flash_attn_stream_k_fixup( } template // D == head size -#if !defined(GGML_USE_HIP) __launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIP) static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const float2 * __restrict__ VKQ_meta, @@ -692,10 +690,7 @@ static __global__ void flash_attn_combine_results( float VKQ_numerator = 0.0f; float VKQ_denominator = 0.0f; for (int l = 0; l < parallel_blocks; ++l) { - const float diff = meta[l].x - kqmax; - float KQ_max_scale = expf(diff); - const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); - *((uint32_t *) &KQ_max_scale) &= ftz_mask; + const float KQ_max_scale = expf(meta[l].x - kqmax); VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid]; VKQ_denominator += KQ_max_scale * meta[l].y; @@ -704,28 +699,6 @@ static __global__ void flash_attn_combine_results( dst[tid] = VKQ_numerator / VKQ_denominator; } -[[noreturn]] -static void on_no_fattn_vec_case(const int D) { - if (D == 64) { - fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); - fprintf(stderr, "By default only f16 KV cache is supported.\n"); - fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n"); - GGML_ABORT("fatal error"); - } else if (D == 128) { - fprintf(stderr, "Unsupported KV type combination for head_size 128.\n"); - fprintf(stderr, "Supported combinations:\n"); - fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n"); - fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n"); - fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n"); - fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n"); - GGML_ABORT("fatal error"); - } else { - fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D); - fprintf(stderr, "Only f16 is supported.\n"); - GGML_ABORT("fatal error"); - } -} - template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, @@ -858,11 +831,10 @@ void launch_fattn( CUDA_CHECK(cudaGetLastError()); } - int parallel_blocks = 1; - const dim3 block_dim(warp_size, nwarps, 1); int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + int parallel_blocks = max_blocks_per_sm; dim3 blocks_num; if (stream_k) { @@ -884,9 +856,6 @@ void launch_fattn( GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. - // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: - parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); - // parallel_blocks must not be larger than what the tensor size allows: parallel_blocks = std::min(parallel_blocks, ntiles_KQ); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 39731baaeb7..57defb0c629 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -767,14 +767,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } #else - GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); - GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); - GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); - GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K); - GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); - GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); - GGML_UNUSED(kb0); GGML_UNUSED(tile_Q); + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, + scale, slope, logit_softcap, ne01, ne02, + stride_K, stride_V, stride_mask, + tile_Q, tile_K, tile_V, tile_mask, + Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -1236,12 +1233,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } #else - GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); - GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); - GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); - GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); - GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dstk_fixup, + scale, slope, logit_softcap, ne01, ne02, + stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, + jt, kb0_start, kb0_stop); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -1395,17 +1390,15 @@ static __global__ void flash_attn_ext_f16( (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); - GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); - GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu deleted file mode 100644 index 1e23f8f79c2..00000000000 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ /dev/null @@ -1,373 +0,0 @@ -#include "common.cuh" -#include "fattn-common.cuh" -#include "fattn-tile-f16.cuh" - -#define FATTN_KQ_STRIDE_TILE_F16 64 - -template // D == head size -#if !defined(GGML_USE_HIP) -__launch_bounds__(nwarps*WARP_SIZE, 2) -#endif // !defined(GGML_USE_HIP) -static __global__ void flash_attn_tile_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) - - // Skip unused kernel variants for faster compilation: -#ifdef FP16_MMA_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FP16_MMA_AVAILABLE - if (use_logit_softcap && !(D == 128 || D == 256)) { - NO_DEVICE_CODE; - return; - } - - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const int stride_KV2 = nb11 / sizeof(half2); - - const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - const half slopeh = __float2half(slopef); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - - __shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16]; - half2 * KQ2 = (half2 *) KQ; - - __shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts. - - half kqmax[ncols/nwarps]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - kqmax[j0/nwarps] = -HALF_MAX_HALF; - } - half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}}; - - half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; - - // Convert Q to half2 and store in registers: - __shared__ half2 Q_h2[ncols][D/2]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); - Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); - } - } - - __syncthreads(); - - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) { - // Calculate KQ tile and keep track of new maximum KQ values: - - half kqmax_new[ncols/nwarps]; -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - kqmax_new[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; - } - } - - __syncthreads(); - - half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}}; - -#pragma unroll - for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) { - half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE]; - half2 Q_k[ncols/nwarps]; - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - - K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ]; - } -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps]; - } - } - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - half sum; - if (use_logit_softcap) { - const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - sum = logit_softcap * tanhf(tmp.x + tmp.y); - } else { - sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - } - sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); - - kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum); - - KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum; - } - } - - __syncthreads(); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]); - const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps])); - kqmax[j0/nwarps] = kqmax_new[j0/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]); - const half2 val = h2exp(diff); - kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val; - KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) { - const int k = k0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i]; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) { - half2 V_k[(D/2)/WARP_SIZE][2]; - half2 KQ_k[ncols/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i]; - V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i]; - } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2]; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]); - VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]); - } - } - } - - __syncthreads(); - } - - //Attention sink: adjust running max and sum once per head - if (sinksf && blockIdx.y == 0) { - const half sink = __float2half(sinksf[head]); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink); - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j)); - kqmax[j0/nwarps] = kqmax_new_j; - - const half val = hexp(sink - kqmax[j0/nwarps]); - kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale; - if (threadIdx.x == 0) { - kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val); - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale; - } - } - } - - float2 * dst2 = (float2 *) dst; - -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { - const int j_VKQ = j_VKQ_0 + threadIdx.y; - - if (ic0 + j_VKQ >= ne01) { - return; - } - - half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); - kqsum_j = warp_reduce_sum((float)kqsum_j); - - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; - -#pragma unroll - for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { - const int i0 = i00 + threadIdx.x; - - half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; - if (gridDim.y == 1) { - dst_val /= __half2half2(kqsum_j); - } - dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val); - } - - if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); - } - } -#else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); - GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); - GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); - GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); - NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) -} - -template -void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { - case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); - } break; - case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); - } break; - default: { - GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); - } break; - } -} - -void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - const int32_t precision = KQV->op_params[3]; - GGML_ASSERT(precision == GGML_PREC_DEFAULT); - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (Q->ne[1] <= 16) { - constexpr int cols_per_block = 16; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f16_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f16_64_128(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 32; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f16_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f16_64_128(ctx, dst); - } -} diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cuh b/ggml/src/ggml-cuda/fattn-tile-f16.cuh deleted file mode 100644 index ffc5878427b..00000000000 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cuh +++ /dev/null @@ -1,3 +0,0 @@ -#include "common.cuh" - -void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu deleted file mode 100644 index c58194937d7..00000000000 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ /dev/null @@ -1,383 +0,0 @@ -#include "common.cuh" -#include "fattn-common.cuh" -#include "fattn-tile-f32.cuh" - -#define FATTN_KQ_STRIDE_TILE_F32 32 - -template // D == head size -#if !defined(GGML_USE_HIP) -__launch_bounds__(nwarps*WARP_SIZE, 2) -#endif // !defined(GGML_USE_HIP) -static __global__ void flash_attn_tile_ext_f32( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#ifdef FLASH_ATTN_AVAILABLE - - // Skip unused kernel variants for faster compilation: -#ifdef FP16_MMA_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FP16_MMA_AVAILABLE - if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); - GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); - GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); - NO_DEVICE_CODE; - return; - } - - // In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const int stride_KV2 = nb11 / sizeof(half2); - - const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - - __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32]; - - __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts. - float2 * KV_tmp2 = (float2 *) KV_tmp; - - float kqmax[ncols/nwarps]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - kqmax[j0/nwarps] = -FLT_MAX/2.0f; - } - float kqsum[ncols/nwarps] = {0.0f}; - - float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; - - // Convert Q to half2 and store in registers: - __shared__ float Q_f[ncols][D]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) { - float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f); - Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale; - Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale; - } - } - - __syncthreads(); - - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) { - // Calculate KQ tile and keep track of new maximum KQ values: - - float kqmax_new[ncols/nwarps]; -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - kqmax_new[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) { - const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x]; - KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp); - KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp); - } - } - - __syncthreads(); - - float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}}; - -#pragma unroll - for (int k_KQ = 0; k_KQ < D; ++k_KQ) { - float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE]; - float Q_k[ncols/nwarps]; - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - - K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ]; - } -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps]; - } - } - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - if (use_logit_softcap) { - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - } - - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; - - kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - - KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]; - } - } - - __syncthreads(); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]); - const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]); - kqmax[j0/nwarps] = kqmax_new[j0/nwarps]; - - float kqsum_add = 0.0f; -#pragma unroll - for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps]; - const float val = expf(diff); - kqsum_add += val; - KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val; - } - kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale; - VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) { - const int k = k0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i]; - KV_tmp2[k*(D/2) + i].x = __low2float(tmp); - KV_tmp2[k*(D/2) + i].y = __high2float(tmp); - } - } - - __syncthreads(); - -#pragma unroll - for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) { - float2 V_k[(D/2)/WARP_SIZE]; - float KQ_k[ncols/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i]; - } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k]; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps]; - VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps]; - } - } - } - - __syncthreads(); - } - - - //Attention sink: adjust running max and sum once per head - if (sinksf && blockIdx.y == 0) { - const float sink = sinksf[head]; - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink); - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j); - kqmax[j0/nwarps] = kqmax_new_j; - - const float val = expf(sink - kqmax[j0/nwarps]); - kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale; - if (threadIdx.x == 0) { - kqsum[j0/nwarps] += val; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale; - VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale; - } - } - } - - float2 * dst2 = (float2 *) dst; - -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { - const int j_VKQ = j_VKQ_0 + threadIdx.y; - - if (ic0 + j_VKQ >= ne01) { - return; - } - - float kqsum_j = kqsum[j_VKQ_0/nwarps]; - kqsum_j = warp_reduce_sum(kqsum_j); - - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; - -#pragma unroll - for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { - const int i0 = i00 + threadIdx.x; - - float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; - if (gridDim.y == 1) { - dst_val.x /= kqsum_j; - dst_val.y /= kqsum_j; - } - dst2[j_dst_unrolled*(D/2) + i0] = dst_val; - } - - if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); - } - } -#else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); - GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); - GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); - NO_DEVICE_CODE; -#endif // FLASH_ATTN_AVAILABLE -} - -template -void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { - case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); - } break; - case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); - } break; - default: { - GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); - } break; - } -} - -void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (Q->ne[1] <= 16) { - constexpr int cols_per_block = 16; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f32_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f32_64_128(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 32; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f32_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f32_64_128(ctx, dst); - } -} diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cuh b/ggml/src/ggml-cuda/fattn-tile-f32.cuh deleted file mode 100644 index b1c546c8054..00000000000 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cuh +++ /dev/null @@ -1,3 +0,0 @@ -#include "common.cuh" - -void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu new file mode 100644 index 00000000000..131a5099a3b --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -0,0 +1,755 @@ +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-tile.cuh" + +// kq_stride == number of KQ rows to process per iteration +// kq_nbatch == number of K columns to load in parallel for KQ calculation + +static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) { + if (GGML_CUDA_CC_IS_AMD(cc)) { + if (GGML_CUDA_CC_IS_RDNA(cc)) { + switch (D) { + case 64: + return 128; + case 128: + case 256: + return ncols <= 16 ? 128 : 64; + default: + GGML_ABORT("fatal error"); + return -1; + } + } + switch (D) { + case 64: + return ncols == 32 ? 128 : 64; + case 128: + return ncols == 32 ? 64 : 32; + case 256: + return 32; + default: + GGML_ABORT("fatal error"); + return -1; + } + } + if (fast_fp16_available(cc)) { + switch (D) { + case 64: + case 128: + case 256: + return ncols <= 16 ? 128 : 64; + default: + GGML_ABORT("fatal error"); + return -1; + } + } + switch (D) { + case 64: + return ncols <= 16 ? 128 : 64; + case 128: + return ncols <= 16 ? 64 : 32; + case 256: + return 32; + default: + GGML_ABORT("fatal error"); + return -1; + } + GGML_UNUSED(warp_size); +} + +static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) { +#ifdef GGML_USE_HIP +#ifdef RDNA + switch (D) { + case 64: + return 128; + case 128: + case 256: + return ncols <= 16 ? 128 : 64; + default: + return -1; + } +#else + switch (D) { + case 64: + return ncols == 32 ? 128 : 64; + case 128: + return ncols == 32 ? 64 : 32; + case 256: + return 32; + default: + return -1; + } +#endif // RDNA +#else +#ifdef FAST_FP16_AVAILABLE + switch (D) { + case 64: + case 128: + case 256: + return ncols <= 16 ? 128 : 64; + default: + return -1; + } +#else + switch (D) { + case 64: + return ncols <= 16 ? 128 : 64; + case 128: + return ncols <= 16 ? 64 : 32; + case 256: + return 32; + default: + return -1; + } +#endif // FAST_FP16_AVAILABLE +#endif // GGML_USE_HIP + GGML_UNUSED_VARS(ncols, warp_size); +} + +static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) { +#ifdef GGML_USE_HIP + switch (D) { + case 64: + return 64; + case 128: + case 256: + return 128; + default: + return -1; + } +#else +#ifdef FAST_FP16_AVAILABLE + switch (D) { + case 64: + return 64; + case 128: + case 256: + return 128; + default: + return -1; + } +#else + switch (D) { + case 64: + return 64; + case 128: + return 128; + case 256: + return ncols <= 16 ? 128 : 64; + default: + return -1; + } +#endif // FAST_FP16_AVAILABLE +#endif // GGML_USE_HIP + GGML_UNUSED_VARS(ncols, warp_size); +} + +static int fattn_tile_get_nthreads_host(const int cc, const int ncols) { + return 256; + GGML_UNUSED_VARS(cc, ncols); +} + +static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) { + return 256; + GGML_UNUSED(ncols); +} + +static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) { +#ifdef RDNA + return 3; +#else + return ncols <= 16 ? 3 : 2; +#endif // RDNA + GGML_UNUSED(ncols); +} + +template // D == head size +__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols)) +static __global__ void flash_attn_tile( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#ifdef FLASH_ATTN_AVAILABLE + + // Skip unused kernel variants for faster compilation: +#ifdef FP16_MMA_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FP16_MMA_AVAILABLE + + if (use_logit_softcap && !(D == 128 || D == 256)) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; + return; + } + + constexpr int warp_size = 32; + constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size; + constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size); + static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size."); + constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size); + static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch"); + + // In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); + const float * sinksf = (const float *) (sinks); + + const int stride_KV2 = nb11 / sizeof(half2); + + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int cpw = ncols/nwarps; // cols per warp + + // softmax_iter_j == number of KQ columns for which to calculate softmax in parallel. + // KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes. +#ifdef FAST_FP16_AVAILABLE + constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne; + + __shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j]; + __shared__ half2 Q_tmp[ncols][D/2]; + __shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts. + half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; +#else + constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne; + + __shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j]; + __shared__ float Q_tmp[ncols][D]; + __shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts. + float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; +#endif // FAST_FP16_AVAILABLE + static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j"); + + float KQ_max[cpw]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + KQ_max[j0/nwarps] = -FLT_MAX/2.0f; + } + float KQ_sum[cpw] = {0.0f}; + + // Load Q data, convert to FP16 if fast. +#pragma unroll + for (int j0 = 0; j0 < cpw; ++j0) { + const int j = j0 + threadIdx.y*cpw; + + constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { + float tmp_f[cpy_ne_D] = {0.0f}; + if (ic0 + j < ne01) { + ggml_cuda_memcpy_1(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]); + } + +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp_f[i1] *= scale; + } + +#ifdef FAST_FP16_AVAILABLE + half2 tmp_h2[cpy_ne_D/2]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { + tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); + } + ggml_cuda_memcpy_1(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2); +#else + ggml_cuda_memcpy_1 (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f); +#endif // FAST_FP16_AVAILABLE + } + } + + __syncthreads(); + + // Main loop over KV cache: + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) { + // Calculate KQ tile and keep track of new maximum KQ values: + + float KQ_max_new[cpw]; +#pragma unroll + for (int j = 0; j < cpw; ++j) { + KQ_max_new[j] = KQ_max[j]; + } + + float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication. + + // KQ = K @ Q matrix multiplication: +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) { +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) { + ggml_cuda_memcpy_1( + &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], + &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]); + } +#else + constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size; +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) { + half2 tmp_h2[cpy_ne_kqnb/2]; + ggml_cuda_memcpy_1( + tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]); + + float2 tmp_f2[cpy_ne_kqnb/2]; +#pragma unroll + for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) { + tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]); + } + ggml_cuda_memcpy_1( + &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2); + } +#endif // FAST_FP16_AVAILABLE + } + + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) { + half2 K_k[kq_stride/warp_size][cpy_ne]; + half2 Q_k[cpw][cpy_ne]; +#else +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) { + float K_k[kq_stride/warp_size][cpy_ne]; + float Q_k[cpw][cpy_ne]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { + const int i_KQ = i_KQ_0 + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]); +#else + ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } +#pragma unroll + for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { + const int j_KQ = j_KQ_0 + threadIdx.y*cpw; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]); +#else + ggml_cuda_memcpy_1(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { +#pragma unroll + for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { +#pragma unroll + for (int k = 0; k < cpy_ne; ++k) { + ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]); + } + } + } + } + + if (k_KQ_0 + kq_nbatch < D) { + __syncthreads(); // Sync not needed on last iteration. + } + } + + // Apply logit softcap, mask, update KQ_max: +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { + const int i_KQ = i_KQ_0 + threadIdx.x; + +#pragma unroll + for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { + const int j_KQ = j_KQ_0 + threadIdx.y*cpw; + + if (use_logit_softcap) { + KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]); + } + + KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; + + KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]); + } + } + + __syncthreads(); + + // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: +#pragma unroll + for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { +#ifdef FAST_FP16_AVAILABLE + half tmp[kq_stride/warp_size][softmax_iter_j]; +#else + float tmp[kq_stride/warp_size][softmax_iter_j]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int j1 = 0; j1 < softmax_iter_j; ++j1) { + KQ_max_new[j0+j1] = warp_reduce_max(KQ_max_new[j0+j1]); + const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]); + KQ_max[j0+j1] = KQ_max_new[j0+j1]; + + float KQ_sum_add = 0.0f; +#pragma unroll + for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { + const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]); + KQ_sum_add += val; + tmp[i0/warp_size][j1] = val; + } + KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale; + VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + ggml_cuda_memcpy_1( + KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]); + } + } + + // VKQ = V @ KQ matrix multiplication: + constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K. + static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter"); +#pragma unroll + for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) { +#pragma unroll + for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) { + const int k_tile = k1 + threadIdx.y; + +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1( + &KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D], + &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]); + } +#else + constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { + half2 tmp_h2[cpy_ne_D/2]; + ggml_cuda_memcpy_1( + tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]); + + float2 tmp_f2[cpy_ne_D/2]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { + tmp_f2[i1] = __half22float2(tmp_h2[i1]); + } + ggml_cuda_memcpy_1( + &KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2); + } +#endif // FAST_FP16_AVAILABLE + } + + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE +#pragma unroll + for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { + half2 V_k[(D/2)/warp_size]; + half2 KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { + const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j); + + half tmp[softmax_iter_j]; + ggml_cuda_memcpy_1( + &tmp, KQ[j][k0 + k1]); +#pragma unroll + for (int j1 = 0; j1 < softmax_iter_j; ++j1) { + KQ_k[j0+j1] = __half2half2(tmp[j1]); + } + } + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { +#pragma unroll + for (int j0 = 0; j0 < cpw; ++j0) { + VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0]; + } + } + } +#else +#pragma unroll + for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { + float2 V_k[(D/2)/warp_size]; + float KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { + const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j); + + ggml_cuda_memcpy_1( + &KQ_k[j0], KQ[j][k0 + k1]); + } + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { +#pragma unroll + for (int j0 = 0; j0 < cpw; ++j0) { + VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0]; + VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0]; + } + } + } +#endif // FAST_FP16_AVAILABLE + + __syncthreads(); + } + } + + + // Attention sink: adjust running max and sum once per head + if (sinksf && blockIdx.y == 0) { + const float sink = sinksf[head]; + +#pragma unroll + for (int j0 = 0; j0 < cpw; ++j0) { + float KQ_max_new_j = fmaxf(KQ_max[j0], sink); + KQ_max_new_j = warp_reduce_max(KQ_max_new_j); + + const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j); + KQ_max[j0] = KQ_max_new_j; + + const float val = expf(sink - KQ_max[j0]); + KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale; + if (threadIdx.x == 0) { + KQ_sum[j0] += val; + } + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0][i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0][i0/warp_size].x *= KQ_max_scale; + VKQ[j0][i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + } + +#pragma unroll + for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { + KQ_sum[j_VKQ_0] = warp_reduce_sum(KQ_sum[j_VKQ_0]); + } + if (gridDim.y == 1) { +#pragma unroll + for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]); +#pragma unroll + for (int i = 0; i < (D/2)/warp_size; ++i) { + VKQ[j_VKQ_0][i] *= KQ_sum_j_inv; + } +#else + const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0]; +#pragma unroll + for (int i = 0; i < (D/2)/warp_size; ++i) { + VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv; + VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv; + } +#endif // FAST_FP16_AVAILABLE + } + } + + // Write back results: +#pragma unroll + for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { + const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw; + + if (ic0 + j_VKQ >= ne01) { + return; + } + + const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { + float2 tmp[cpy_ne_D]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]); + } + ggml_cuda_memcpy_1(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); + } +#else + constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1( + &dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]); + } +#endif // FAST_FP16_AVAILABLE + + if (gridDim.y != 1 && threadIdx.x == 0) { + dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]); + } + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE +} + +template +static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int warp_size = 32; + + constexpr size_t nbytes_shared = 0; + +#ifdef GGML_USE_HIP + if constexpr (D <= 128) { + if (Q->ne[1] > 32) { + constexpr int cols_per_block = 64; + const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; + fattn_kernel_t fattn_kernel = flash_attn_tile; + const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); + return; + } + } +#endif // GGML_USE_HIP + + if (Q->ne[1] > 16) { + constexpr int cols_per_block = 32; + const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; + fattn_kernel_t fattn_kernel = flash_attn_tile; + const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); + return; + } + + constexpr int cols_per_block = 16; + const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; + fattn_kernel_t fattn_kernel = flash_attn_tile; + const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); +} + +template +static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + switch (Q->ne[0]) { + case 64: { + launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst); + } break; + case 128: { + launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst); + } break; + case 256: { + launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst); + } break; + default: { + GGML_ABORT("Unsupported head size"); + } break; + } +} + +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_switch_head_size(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_switch_head_size(ctx, dst); + } +} diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh new file mode 100644 index 00000000000..10dc22d1bf9 --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index b05f682cd3b..27a2dd6ae44 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -349,17 +349,15 @@ static __global__ void flash_attn_vec_ext_f16( dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); - GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); - GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index d6d0bfb744b..da195d0334d 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -37,17 +37,15 @@ static __global__ void flash_attn_vec_ext_f32( // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); - GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); - GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); - GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); NO_DEVICE_CODE; return; } @@ -345,17 +343,15 @@ static __global__ void flash_attn_vec_ext_f32( dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); - GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); - GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 6bc7943ccd5..2219191fd91 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -471,16 +471,15 @@ static __global__ void flash_attn_ext_f16( dst_meta[j_dst_unrolled] = dst_meta_val; } #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31); - GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); - GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); - GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 22e90d0e7b3..7626d89ca08 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -1,8 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" #include "fattn-mma-f16.cuh" -#include "fattn-tile-f16.cuh" -#include "fattn-tile-f32.cuh" +#include "fattn-tile.cuh" #include "fattn-vec-f16.cuh" #include "fattn-vec-f32.cuh" #include "fattn-wmma-f16.cuh" @@ -190,7 +189,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) #endif // GGML_CUDA_FA_ALL_QUANTS - on_no_fattn_vec_case(Q->ne[0]); + GGML_ABORT("fatal error"); } #define FATTN_VEC_F32_CASE(D, type_K, type_V) \ @@ -265,74 +264,177 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) #endif // GGML_CUDA_FA_ALL_QUANTS - on_no_fattn_vec_case(Q->ne[0]); + GGML_ABORT("fatal error"); } -void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +// Best FlashAttention kernel for a specific GPU: +enum best_fattn_kernel { + BEST_FATTN_KERNEL_NONE = 0, + BEST_FATTN_KERNEL_TILE = 200, + BEST_FATTN_KERNEL_VEC_F32 = 100, + BEST_FATTN_KERNEL_VEC_F16 = 110, + BEST_FATTN_KERNEL_WMMA_F16 = 300, + BEST_FATTN_KERNEL_MMA_F16 = 400, +}; + +static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) { +#ifndef FLASH_ATTN_AVAILABLE + GGML_UNUSED(device); GGML_UNUSED(dst); + return BEST_FATTN_KERNEL_NONE; +#endif// FLASH_ATTN_AVAILABLE + const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; const ggml_tensor * mask = dst->src[3]; - ggml_cuda_set_device(ctx.device); - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + + const int cc = ggml_cuda_info().devices[device].cc; + const int warp_size = ggml_cuda_info().devices[device].warp_size; const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); -#if defined(GGML_HIP_ROCWMMA_FATTN) - if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) { - ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); - return; + switch (K->ne[0]) { + case 64: + case 128: + case 256: + if (V->ne[0] != K->ne[0]) { + return BEST_FATTN_KERNEL_NONE; + } + break; + case 80: + case 96: + case 112: + if (V->ne[0] != K->ne[0]) { + return BEST_FATTN_KERNEL_NONE; + } + if (!fp16_mma_available(cc) && !turing_mma_available(cc)) { + return BEST_FATTN_KERNEL_NONE; + } + break; + case 576: + if (V->ne[0] != 512) { + return BEST_FATTN_KERNEL_NONE; + } + if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) { + return BEST_FATTN_KERNEL_NONE; + } + break; + default: + return BEST_FATTN_KERNEL_NONE; } -#endif // defined(GGML_HIP_ROCWMMA_FATTN) - if (!fast_fp16_available(cc)) { - if (Q->ne[1] <= 8 || Q->ne[0] == 256) { - ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); - } else { - ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); - } - return; +#ifndef GGML_CUDA_FA_ALL_QUANTS + if (K->type != V->type) { + return BEST_FATTN_KERNEL_NONE; } +#endif // GGML_CUDA_FA_ALL_QUANTS - if (!fp16_mma_available(cc)) { - if (prec == GGML_PREC_DEFAULT) { - if (Q->ne[1] <= 8 || Q->ne[0] == 256) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - } else { - ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + switch (K->type) { + case GGML_TYPE_F16: + break; + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: +#ifndef GGML_CUDA_FA_ALL_QUANTS + return BEST_FATTN_KERNEL_NONE; +#endif // GGML_CUDA_FA_ALL_QUANTS + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: +#ifdef GGML_CUDA_FA_ALL_QUANTS + if (K->ne[0] != 128 && K->ne[0] != 64) { + return BEST_FATTN_KERNEL_NONE; } - } else { - if (Q->ne[1] <= 8 || Q->ne[0] == 256) { - ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); - } else { - ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); +#else + if (K->ne[0] != 128) { + return BEST_FATTN_KERNEL_NONE; } - } - return; +#endif // GGML_CUDA_FA_ALL_QUANTS + break; + default: + return BEST_FATTN_KERNEL_NONE; + } + + switch (V->type) { + case GGML_TYPE_F16: + break; + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + if (K->ne[0] != 128) { + return BEST_FATTN_KERNEL_NONE; + } + break; + default: + return BEST_FATTN_KERNEL_NONE; + } + + if (mask && mask->ne[2] != 1) { + return BEST_FATTN_KERNEL_NONE; } - const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations - const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; - const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192); - const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion && - (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000); const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0; - if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { - if (prec == GGML_PREC_DEFAULT) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - } else { - ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); + + // If Turing tensor cores available, use them except for some cases with batch size 1: + if (turing_mma_available(cc)) { + const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask; // The mma-based kernels have GQA-specific optimizations + const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; + const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (gqa_ratio > 4 && K->ne[1] >= 8192); + const bool mma_faster_for_bs1 = gqa_opt_applies && !mma_needs_data_conversion && + (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000); + if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { + if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { + return BEST_FATTN_KERNEL_VEC_F16; + } + return BEST_FATTN_KERNEL_VEC_F32; } - return; + return BEST_FATTN_KERNEL_MMA_F16; } - // The MMA implementation needs Turing or newer, use the old WMMA code for Volta: - if (fp16_mma_available(cc) && !turing_mma_available(cc)) { - ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); - return; + // Use kernels specializes for small batch sizes if possible: + if (Q->ne[1] <= 8 && can_use_vector_kernel) { + if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { + return BEST_FATTN_KERNEL_VEC_F16; + } + return BEST_FATTN_KERNEL_VEC_F32; + } + + // For large batch sizes, use the WMMA kernel if possible: + if (fp16_mma_available(cc)) { + return BEST_FATTN_KERNEL_WMMA_F16; + } + + // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes: + return BEST_FATTN_KERNEL_TILE; +} + +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_set_device(ctx.device); + switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) { + case BEST_FATTN_KERNEL_NONE: + GGML_ABORT("fatal error"); + case BEST_FATTN_KERNEL_TILE: + ggml_cuda_flash_attn_ext_tile(ctx, dst); + break; + case BEST_FATTN_KERNEL_VEC_F32: + ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); + break; + case BEST_FATTN_KERNEL_VEC_F16: + ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + break; + case BEST_FATTN_KERNEL_WMMA_F16: + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + break; + case BEST_FATTN_KERNEL_MMA_F16: + ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); + break; } +} - ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); +bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) { + return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE; } diff --git a/ggml/src/ggml-cuda/fattn.cuh b/ggml/src/ggml-cuda/fattn.cuh index ad3ca7a8d8e..78705d59951 100644 --- a/ggml/src/ggml-cuda/fattn.cuh +++ b/ggml/src/ggml-cuda/fattn.cuh @@ -1,3 +1,5 @@ #include "common.cuh" void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 68d3254fbe4..2fab33243dd 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -6,64 +6,66 @@ template static __global__ void k_get_rows( const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ - /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. - const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2; - const int i10 = blockIdx.x; - const int i11 = blockIdx.z / ne12; - const int i12 = blockIdx.z % ne12; + for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i10 = blockIdx.x; + const int i11 = z / ne12; // TODO fastdiv + const int i12 = z % ne12; - if (i00 >= ne00) { - return; - } - - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; - const int ib = i00/qk; // block index - const int iqs = (i00%qk)/qr; // quant index - const int iybs = i00 - i00%qk; // dst block start index - const int y_offset = qr == 1 ? 1 : qk/2; + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index + const int iybs = i00 - i00%qk; // dst block start index + const int y_offset = qr == 1 ? 1 : qk/2; - // dequantize - dfloat2 v; - dequantize_kernel(src0_row, ib, iqs, v); + // dequantize + float2 v; + dequantize_kernel(src0_row, ib, iqs, v); - dst_row[iybs + iqs + 0] = ggml_cuda_cast(v.x); - dst_row[iybs + iqs + y_offset] = ggml_cuda_cast(v.y); + dst_row[iybs + iqs + 0] = ggml_cuda_cast(v.x); + dst_row[iybs + iqs + y_offset] = ggml_cuda_cast(v.y); + } + } } template static __global__ void k_get_rows_float( const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ - /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. - const int i00 = blockIdx.y * blockDim.x + threadIdx.x; - const int i10 = blockIdx.x; - const int i11 = blockIdx.z / ne12; - const int i12 = blockIdx.z % ne12; + for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i10 = blockIdx.x; + const int i11 = z / ne12; // TODO fastdiv + const int i12 = z % ne12; - if (i00 >= ne00) { - return; - } + if (i00 >= ne00) { + return; + } - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); - dst_row[i00] = ggml_cuda_cast(src0_row[i00]); + dst_row[i00] = ggml_cuda_cast(src0_row[i00]); + } + } } template @@ -98,7 +100,7 @@ static void get_rows_cuda_q( cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); - const dim3 block_nums(ne10, block_num_y, ne11*ne12); + const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX)); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); @@ -116,7 +118,7 @@ static void get_rows_cuda_q( k_get_rows<<>>( src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ + /*ne10,*/ ne11, ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); @@ -131,7 +133,7 @@ static void get_rows_cuda_float( cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; - const dim3 block_nums(ne10, block_num_y, ne11*ne12); + const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX)); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); @@ -147,7 +149,7 @@ static void get_rows_cuda_float( k_get_rows_float<<>>( src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ + /*ne10,*/ ne11, ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index d6402a8daac..f3ba20fe3f7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -12,6 +12,7 @@ #include "ggml-cuda/clamp.cuh" #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" +#include "ggml-cuda/conv2d.cuh" #include "ggml-cuda/conv2d-dw.cuh" #include "ggml-cuda/conv2d-transpose.cuh" #include "ggml-cuda/convert.cuh" @@ -49,6 +50,7 @@ #include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" #include "ggml-cuda/set-rows.cuh" +#include "ggml-cuda/pad_reflect_1d.cuh" #include "ggml.h" #include @@ -203,6 +205,8 @@ static ggml_cuda_device_info ggml_cuda_init() { GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__); #endif // GGML_CUDA_FORCE_CUBLAS GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); + + std::vector> turing_devices_without_mma; for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; @@ -260,7 +264,25 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc = 100*prop.major + 10*prop.minor; GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); -#endif // defined(GGML_USE_HIP) + std::string device_name(prop.name); + if (device_name == "NVIDIA GeForce MX450") { + turing_devices_without_mma.push_back({ id, device_name }); + } else if (device_name == "NVIDIA GeForce MX550") { + turing_devices_without_mma.push_back({ id, device_name }); + } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") { + turing_devices_without_mma.push_back({ id, device_name }); + } +#endif // defined(GGML_USE_HIP) + } + + if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) { + GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n"); + for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) { + GGML_LOG_INFO( + " Device %d: %s\n", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str()); + } + GGML_LOG_INFO( + "Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n"); } for (int id = 0; id < info.device_count; ++id) { @@ -1328,9 +1350,7 @@ static void ggml_cuda_op_mul_mat_cublas( &beta, dst_dd_i, ldc)); } - GGML_UNUSED(dst); - GGML_UNUSED(src1_ddq_i); - GGML_UNUSED(src1_padded_row_size); + GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size); } static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { @@ -2089,6 +2109,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst); return; } + + if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) { + ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst); + return; + } } cudaStream_t stream = ctx.stream(); @@ -2354,6 +2379,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_PAD: ggml_cuda_op_pad(ctx, dst); break; + case GGML_OP_PAD_REFLECT_1D: + ggml_cuda_op_pad_reflect_1d(ctx, dst); + break; case GGML_OP_ARANGE: ggml_cuda_op_arange(ctx, dst); break; @@ -2429,6 +2457,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_IM2COL: ggml_cuda_op_im2col(ctx, dst); break; + case GGML_OP_IM2COL_3D: + ggml_cuda_op_im2col_3d(ctx, dst); + break; + case GGML_OP_CONV_2D: + ggml_cuda_op_conv2d(ctx, dst); + break; case GGML_OP_CONV_2D_DW: ggml_cuda_op_conv2d_dw(ctx, dst); break; @@ -2795,9 +2829,14 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } - if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + const ggml_tensor *add = nullptr; + + if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) { + add = cgraph->nodes[node_idx+2]; + } GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); @@ -2809,6 +2848,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } + if (add && (add->src[0]->type != GGML_TYPE_F32 || + add->src[1]->type != GGML_TYPE_F32 || + add->type != GGML_TYPE_F32) ) { + return false; + } + //if rms norm is the B operand, then we don't handle broadcast if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { return false; @@ -2819,6 +2864,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } + if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) { + return false; + } + return true; } @@ -2865,7 +2914,46 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); if (!disable_fusion) { - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) { + + if (node->op == GGML_OP_ADD) { + int n_fuse = 0; + ggml_op ops[8]; + std::fill(ops, ops + 8, GGML_OP_ADD); + + for (; n_fuse <= 6; ++n_fuse){ + if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { + break; + } + if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) { + break; + } + if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) { + break; + } + } + + n_fuse++; + + if (n_fuse > 1) { + for (int j = 0; j < n_fuse - 1; ++j) { + node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + } + cgraph->nodes[i + n_fuse - 1]->data = node->data; + ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse); + i += n_fuse - 1; + + continue; + } + } + + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { + ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + continue; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) { ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); i++; continue; @@ -3052,6 +3140,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, + /* .graph_optimize = */ NULL, }; static ggml_guid_t ggml_backend_cuda_guid() { @@ -3084,7 +3173,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { return false; } -#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) +#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP) cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); if (err != cudaSuccess) { // clear the error @@ -3121,6 +3210,7 @@ struct ggml_backend_cuda_device_context { int device; std::string name; std::string description; + std::string pci_bus_id; }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { @@ -3145,9 +3235,12 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend } static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + props->name = ggml_backend_cuda_device_get_name(dev); props->description = ggml_backend_cuda_device_get_description(dev); props->type = ggml_backend_cuda_device_get_type(dev); + props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; @@ -3378,6 +3471,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { return true; } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) { + return true; + } + if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) { + return true; + } if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) { return true; } @@ -3479,19 +3578,22 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]); } case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: case GGML_OP_ARGSORT: case GGML_OP_ACC: return true; + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_GROUP_NORM: + case GGML_OP_PAD: return ggml_is_contiguous(op->src[0]); case GGML_OP_UPSCALE: - case GGML_OP_PAD: + case GGML_OP_PAD_REFLECT_1D: case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: @@ -3499,44 +3601,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_RWKV_WKV7: return true; - case GGML_OP_FLASH_ATTN_EXT: { -#ifndef FLASH_ATTN_AVAILABLE - return false; -#endif // FLASH_ATTN_AVAILABLE - if (op->src[1]->ne[0] != op->src[2]->ne[0]) { - const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; - if (!turing_mma_available(cc)) { - return false; - } - const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2]; - return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0; - } - // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS] - if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) - && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { - return false; - } - if (op->src[0]->ne[0] == 192) { - return false; - } - if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) { - return false; - } - if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) { - return true; - } - if (op->src[0]->ne[0] == 128) { - return true; - } - if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) { - return true; - } - if (op->src[3] && op->src[3]->ne[2] != 1) { - return false; - } - return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) && - op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; - } + case GGML_OP_FLASH_ATTN_EXT: + return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op); case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_OPT_STEP_ADAMW: @@ -3672,10 +3738,6 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t features.push_back({ "NO_PEER_COPY", "1" }); #endif - #ifdef GGML_CUDA_F16 - features.push_back({ "F16", "1" }); - #endif - #ifdef GGML_CUDA_USE_GRAPHS features.push_back({ "USE_GRAPHS", "1" }); #endif @@ -3746,6 +3808,10 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; + char pci_bus_id[16] = {}; + snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); + dev_ctx->pci_bus_id = pci_bus_id; + ggml_backend_dev_t dev = new ggml_backend_device { /* .iface = */ ggml_backend_cuda_device_interface, /* .reg = */ ®, diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index 16bb9bec97d..56dc0545742 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -112,3 +112,153 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); } } + +// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] +template +static __global__ void im2col_3d_kernel( + const float * src, T * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW, + int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW, + int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) { + const int64_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= IC_KD_KH_KW) { + return; + } + GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH); + GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW); + + const int64_t iic = i / KD_KH_KW; + const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW; + const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; + const int64_t ikw = i % KW; + + const int64_t iow = blockIdx.y; + for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) { + const int64_t in = iz / OD_OH; + const int64_t iod = (iz - in*OD_OH) / OH; + const int64_t ioh = iz % OH; + + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iid = iod * s2 + ikd * d2 - p2; + + const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); + dst[offset_dst] = src[offset_src]; + } + } +} + +// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] +template +static void im2col_3d_cuda(const float * src, T* dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { + const int64_t OH_OW = OH*OW; + const int64_t KD_KH_KW = KD*KH*KW; + const int64_t ID_IH_IW = ID*IH*IW; + const int64_t KH_KW = KH*KW; + const int64_t IH_IW = IH*IW; + const int64_t IC_KD_KH_KW = IC*KD*KH*KW; + const int64_t OW_KD_KH_KW = OW*KD*KH*KW; + const int64_t N_OD_OH = N*OD*OH; + const int64_t OD_OH = OD*OH; + const int64_t IC_ID_IH_IW = IC*ID*IH*IW; + const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; + const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; + dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z)); + im2col_3d_kernel<<>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, + IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW, + OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2); +} + +static void im2col_3d_cuda_f16(const float * src, half * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { + + im2col_3d_cuda(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); +} + +static void im2col_3d_cuda_f32(const float * src, float * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { + + im2col_3d_cuda(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); +} + +void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t OC = ne03 / IC; + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + + const size_t es = ggml_element_size(src1); + const int64_t stride_x = src1->nb[0] / es; + const int64_t stride_y = src1->nb[1] / es; + const int64_t stride_z = src1->nb[2] / es; + const int64_t stride_q = src1->nb[3] / es; + + if(dst->type == GGML_TYPE_F16) { + im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); + } else { + im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); + } +} diff --git a/ggml/src/ggml-cuda/im2col.cuh b/ggml/src/ggml-cuda/im2col.cuh index 1ce8fae4d9a..2da1223d634 100644 --- a/ggml/src/ggml-cuda/im2col.cuh +++ b/ggml/src/ggml-cuda/im2col.cuh @@ -3,3 +3,4 @@ #define CUDA_IM2COL_BLOCK_SIZE 256 void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 83ee16b27d0..c1f24243fe3 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -1,3 +1,4 @@ +#pragma once // This file contains primitives that expose the tensor core PTX instructions for CUDA code. // The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. // The documentation for the PTX instructions can be found under: @@ -291,9 +292,7 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); #else - GGML_UNUSED(t); - GGML_UNUSED(xs0); - GGML_UNUSED(stride); + GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -315,9 +314,7 @@ namespace ggml_cuda_mma { : "r"(A.x[1]), "r"(B.x[0])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -345,9 +342,7 @@ namespace ggml_cuda_mma { : "r"(A.x[3]), "r"(B.x[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -372,9 +367,7 @@ namespace ggml_cuda_mma { : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -408,9 +401,7 @@ namespace ggml_cuda_mma { : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -425,9 +416,7 @@ namespace ggml_cuda_mma { : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // AMPERE_MMA_AVAILABLE } @@ -452,9 +441,7 @@ namespace ggml_cuda_mma { : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -469,9 +456,7 @@ namespace ggml_cuda_mma { : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // AMPERE_MMA_AVAILABLE } @@ -505,9 +490,7 @@ namespace ggml_cuda_mma { : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -533,9 +516,7 @@ namespace ggml_cuda_mma { 0, 0, 0); #endif // defined(CDNA3) #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // AMD_MFMA_AVAILABLE } @@ -561,9 +542,7 @@ namespace ggml_cuda_mma { 0, 0, 0); #endif // defined(CDNA3) #else - GGML_UNUSED(D); - GGML_UNUSED(A); - GGML_UNUSED(B); + GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // AMD_MFMA_AVAILABLE } diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 1437367e871..16331e9ecfa 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -1,344 +1,12 @@ #include "ggml.h" -#include "common.cuh" -#include "mma.cuh" #include "mmf.cuh" -using namespace ggml_cuda_mma; - -#define MMF_ROWS_PER_BLOCK 32 - -template -__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) -static __global__ void mul_mat_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, - const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - typedef tile<16, 8, T> tile_A; - typedef tile< 8, 8, T> tile_B; - typedef tile<16, 8, float> tile_C; - - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - constexpr int tile_k_padded = warp_size + 4; - constexpr int ntA = rows_per_block / tile_A::I; - constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; - - const int row0 = blockIdx.x * rows_per_block; - const int channel_dst = blockIdx.y; - const int channel_x = channel_dst / channel_ratio; - const int channel_y = channel_dst; - const int sample_dst = blockIdx.z; - const int sample_x = sample_dst / sample_ratio; - const int sample_y = sample_dst; - - x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ; - y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; - dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; - - const float2 * y2 = (const float2 *) y; - - extern __shared__ char data_mmv[]; - - tile_C C[ntA][ntB]; - - T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded); - - for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { - tile_A A[ntA][warp_size / tile_A::J]; -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { -#pragma unroll - for (int i = 0; i < tile_A::I; ++i) { - tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; - } -#pragma unroll - for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { - load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); - } - } - -#pragma unroll - for (int itB = 0; itB < ntB; ++itB) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int j0 = 0; j0 < tile_B::I; ++j0) { - const int j = j0 + itB*tile_B::I; - - tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f; - } - } else if constexpr (std::is_same_v || std::is_same_v) { -#pragma unroll - for (int j0 = 0; j0 < tile_B::I; ++j0) { - const int j = j0 + itB*tile_B::I; - - const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); - tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; - } - } else { - static_assert(std::is_same_v, "unsupported type"); - } -#pragma unroll - for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { - tile_B B; - load_ldmatrix(B, tile_xy + k0, tile_k_padded); -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { - mma(C[itA][itB], A[itA][k0/tile_B::J], B); - } - } - } - } - - float * buf_iw = (float *) data_mmv; - constexpr int kiw = nwarps*rows_per_block + 4; - - if (nwarps > 1) { - __syncthreads(); - } -#pragma unroll - for (int itB = 0; itB < ntB; ++itB) { -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); - const int j = itB*tile_C::J + tile_C::get_j(l); - buf_iw[j*kiw + i] = C[itA][itB].x[l]; - } - } - } - - if (nwarps > 1) { - __syncthreads(); - } - -#pragma unroll - for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (j0 + nwarps > cols_per_block && j >= cols_per_block) { - return; - } - - float sum = 0.0f; - static_assert(rows_per_block == warp_size, "need loop/check"); -#pragma unroll - for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { - const int i = i0 + threadIdx.x; - - sum += buf_iw[j*kiw + i]; - } - dst[j*stride_col_dst + row0 + threadIdx.x] = sum; - } -#else - NO_DEVICE_CODE; - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(ids); GGML_UNUSED(dst); - GGML_UNUSED(ncols); GGML_UNUSED(nchannels_y); GGML_UNUSED(stride_row); GGML_UNUSED(stride_col_y); GGML_UNUSED(stride_col_dst); - GGML_UNUSED(channel_ratio); GGML_UNUSED(stride_channel_x); GGML_UNUSED(stride_channel_y); GGML_UNUSED(stride_channel_dst); - GGML_UNUSED(sample_ratio); GGML_UNUSED(stride_sample_x); GGML_UNUSED(stride_sample_y); GGML_UNUSED(stride_sample_dst); -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) -} - -template -static void mul_mat_f_cuda( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols_x, const int64_t nrows_x, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - typedef tile<16, 8, T> tile_A; - typedef tile< 8, 8, T> tile_B; - typedef tile<16, 8, float> tile_C; - - GGML_ASSERT(!ids && "mul_mat_id not implemented"); - - GGML_ASSERT(ncols_x % 2 == 0); - GGML_ASSERT(stride_row % 2 == 0); - GGML_ASSERT(stride_col_y % 2 == 0); - GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); - GGML_ASSERT( nsamples_dst % nsamples_x == 0); - const int64_t channel_ratio = nchannels_dst / nchannels_x; - const int64_t sample_ratio = nsamples_dst / nsamples_x; - - const int device = ggml_cuda_get_device(); - const int warp_size = ggml_cuda_info().devices[device].warp_size; - - int64_t nwarps_best = 1; - int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); - int64_t max_block_size = 256; - for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { - const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); - if (niter < niter_best) { - niter_best = niter; - nwarps_best = nwarps; - } - } - - constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; - const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4; - const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; - const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); - const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst); - const dim3 block_dims(warp_size, nwarps_best, 1); - switch (nwarps_best) { - case 1: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 2: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 3: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 4: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 5: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 6: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 7: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 8: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - default: { - GGML_ABORT("fatal error"); - } break; - } -} - -template -static void mul_mat_f_switch_cols_per_block( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - switch (ncols_dst) { - case 1: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 2: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 3: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 4: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 5: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 6: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 7: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 8: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 9: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 10: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 11: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 12: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 13: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 14: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 15: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 16: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - default: { - GGML_ABORT("fatal error"); - } break; - } -} - void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_TENSOR_BINARY_OP_LOCALS; const size_t ts_src0 = ggml_type_size(src0->type); @@ -352,9 +20,6 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); GGML_ASSERT( nb0 == ts_dst); - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; - const float * src1_d = (const float *) src1->data; const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; float * dst_d = (float *) dst->data; @@ -369,55 +34,72 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr const int64_t s13 = src1->nb[3] / ts_src1; const int64_t s3 = dst->nb[3] / ts_dst; + const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0; + const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: const int64_t ncols_dst = ids ? ne2 : ne1; - const int64_t nchannels_y = ids ? ne11 : ne12; - const int64_t nchannels_dst = ids ? ne1 : ne2; - const int64_t stride_channel_dst = ids ? s1 : s2; - const int64_t stride_channel_y = ids ? s11 : s12; + const int64_t nchannels_dst = ids ? ne1 : ne2; + + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; + const int64_t stride_channel_dst = ids ? s1 : s2; - GGML_ASSERT(!ids || ncols_dst == 1); + int64_t stride_channel_y = ids ? s11 : s12; + int64_t nchannels_y = ids ? ne11 : ne12; + + //mul_mat_id: handle broadcast + if (ids && nchannels_y == 1) { + stride_channel_y = 0; + nchannels_y = ids->ne[0]; + } switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; constexpr int vals_per_T = 1; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; case GGML_TYPE_F16: { const half2 * src0_d = (const half2 *) src0->data; constexpr int vals_per_T = 2; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; constexpr int vals_per_T = 2; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); } } -bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) { +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) { + + if (ggml_is_quantized(type)) { + return false; + } + if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) { return false; } if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) { return false; } - if (ne11 > 16) { + if (src1_ncols > 16) { return false; } + switch (type) { case GGML_TYPE_F32: return ampere_mma_available(cc); diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index 785f9f211c3..61e3bf30152 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -1,5 +1,461 @@ +#pragma once + +#include "mma.cuh" #include "common.cuh" +using namespace ggml_cuda_mma; + +#define MMF_ROWS_PER_BLOCK 32 + void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); -bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11); +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols); + +template +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f( + const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, + const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int stride_col_id, const int stride_row_id, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int tile_k_padded = warp_size + 4; + constexpr int ntA = rows_per_block / tile_A::I; + constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; + + const int row0 = blockIdx.x * rows_per_block; + + const int expert_idx = has_ids ? blockIdx.y : 0; + const int channel_dst = has_ids ? 0 : blockIdx.y; + + const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio); + const int channel_y = channel_dst; + const int sample_dst = blockIdx.z; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ; + y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y); + dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst); + + const float2 * y2 = (const float2 *) y; + + extern __shared__ char data_mmv[]; + + char * shmem_base = data_mmv; + int * slot_map = (int *) shmem_base; + char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base; + + tile_C C[ntA][ntB]; + + T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded); + + if constexpr (has_ids) { + int found = 0; + + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + const int32_t * __restrict__ id_row = ids + j*stride_row_id; + + if (threadIdx.x == 0) { + slot_map[j] = -1; + } + + for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) { + int match = id_row[k*stride_col_id] == expert_idx; + + if (match) { + slot_map[j] = k; + found = 1; + break; + } + } + } + + if (!__syncthreads_or(found)) { + return; + } + } + + + for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { + tile_A A[ntA][warp_size / tile_A::J]; +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int i = 0; i < tile_A::I; ++i) { + tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { + load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); + } + } + +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + if constexpr (!has_ids) { + tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f; + } else { + tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f; + } + } + } else if constexpr (std::is_same_v || std::is_same_v) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + if constexpr (!has_ids) { + const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } else { + float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f); + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + } + } + + float * buf_iw = (float *) compute_base; + constexpr int kiw = nwarps*rows_per_block + 4; + + if (nwarps > 1) { + __syncthreads(); + } +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); + const int j = itB*tile_C::J + tile_C::get_j(l); + buf_iw[j*kiw + i] = C[itA][itB].x[l]; + } + } + } + + if (nwarps > 1) { + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > cols_per_block && j >= cols_per_block) { + return; + } + + float sum = 0.0f; + static_assert(rows_per_block == warp_size, "need loop/check"); +#pragma unroll + for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { + const int i = i0 + threadIdx.x; + + sum += buf_iw[j*kiw + i]; + } + + if constexpr (!has_ids) { + dst[j*stride_col_dst + row0 + threadIdx.x] = sum; + } else { + const int slot = (j < cols_per_block) ? slot_map[j] : -1; + if (slot >= 0) { + dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum; + } + } + } +#else + GGML_UNUSED_VARS(x, y, ids, dst, + ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + NO_DEVICE_CODE; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +} + +template +static inline void mul_mat_f_switch_ids( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nchannels_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int64_t stride_row_id, + const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) { + if (ids) { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } else { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } +} + +template +void mul_mat_f_cuda( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int64_t stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + + GGML_ASSERT(ncols_x % 2 == 0); + GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(stride_col_y % 2 == 0); + GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); + GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const int64_t channel_ratio = nchannels_dst / nchannels_x; + const int64_t sample_ratio = nsamples_dst / nsamples_x; + + const int device = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[device].warp_size; + + int64_t nwarps_best = 1; + int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); + int64_t max_block_size = 256; + for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { + const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); + if (niter < niter_best) { + niter_best = niter; + nwarps_best = nwarps; + } + } + + constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; + const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4; + const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; + const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); + const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; + const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; + const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present + + const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst); + const dim3 block_dims(warp_size, nwarps_best, 1); + + switch (nwarps_best) { + case 1: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 2: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 3: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 4: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 5: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 6: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 7: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 8: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } + + GGML_UNUSED_VARS(nchannels_y); +} + +template +static void mul_mat_f_switch_cols_per_block( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + switch (ncols_dst) { + case 1: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 2: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 3: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 4: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 5: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 6: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 7: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 8: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 9: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 10: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 11: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 12: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 13: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 14: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 15: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 16: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +#define DECL_MMF_CASE_HELPER(T, ncols_dst) \ + template void mul_mat_f_cuda( \ + const T * x, const float * y, const int32_t * ids, float * dst, \ + const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \ + const int64_t stride_col_id, const int64_t stride_row_id, \ + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \ + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\ + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ + cudaStream_t stream); + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#define DECL_MMF_CASE_EXTERN(ncols_dst) \ + extern DECL_MMF_CASE_HELPER(float, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + +#define DECL_MMF_CASE(ncols_dst) \ + DECL_MMF_CASE_HELPER(float, ncols_dst) \ + DECL_MMF_CASE_HELPER(half2, ncols_dst) \ + DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + +DECL_MMF_CASE_EXTERN(1); +DECL_MMF_CASE_EXTERN(2); +DECL_MMF_CASE_EXTERN(3); +DECL_MMF_CASE_EXTERN(4); +DECL_MMF_CASE_EXTERN(5); +DECL_MMF_CASE_EXTERN(6); +DECL_MMF_CASE_EXTERN(7); +DECL_MMF_CASE_EXTERN(8); +DECL_MMF_CASE_EXTERN(9); +DECL_MMF_CASE_EXTERN(10); +DECL_MMF_CASE_EXTERN(11); +DECL_MMF_CASE_EXTERN(12); +DECL_MMF_CASE_EXTERN(13); +DECL_MMF_CASE_EXTERN(14); +DECL_MMF_CASE_EXTERN(15); +DECL_MMF_CASE_EXTERN(16); +#else +#define DECL_MMF_CASE(ncols_dst) +#endif diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 384ee7615f7..714b23f9f49 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -3,6 +3,140 @@ #include +// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each. +struct mmq_ids_helper_store { + uint32_t data; + + __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) { + data = (it & 0x003FFFFF) | (iex_used << 22); + } + + __device__ uint32_t it() const { + return data & 0x003FFFFF; + } + + __device__ uint32_t iex_used() const { + return data >> 22; + } +}; +static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store"); + +// Helper function for mul_mat_id, converts ids to a more convenient format. +// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert. +// ids_dst describes the same mapping but for the dst tensor. +// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1]. +template +__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mmq_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template; + const int expert = blockIdx.x; + + extern __shared__ char data_mmq_ids_helper[]; + mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper; + + int nex_prev = 0; // Number of columns for experts with a lower index. + int it_compact = 0; // Running index for the compact slice of this expert. + + if constexpr (n_expert_used_template == 0) { + // Generic implementation: + for (int it = 0; it < n_tokens; ++it) { + int iex_used = -1; // The index at which the expert is used, if any. + for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) { + const int expert_used = ids[it*si1 + iex]; + nex_prev += expert_used < expert; + if (expert_used == expert) { + iex_used = iex; + } + } + + if (iex_used != -1) { + store[it_compact] = mmq_ids_helper_store(it, iex_used); + } + + if (warp_reduce_any(iex_used != -1)) { + it_compact++; + } + } + } else { + // Implementation optimized for specific numbers of experts used: + static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used"); + const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2. + for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) { + const int it = it0 + threadIdx.x / neu_padded; + + const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any. + const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ? + ids[it*si1 + iex] : INT_MAX; + const int iex_used = expert_used == expert ? iex : -1; + nex_prev += expert_used < expert; + + // Whether the threads at this token position have used the expert: + const int it_compact_add_self = warp_reduce_any(iex_used != -1); + + // Do a scan over threads at lower token positions in warp to get the correct index for writing data: + int it_compact_add_lower = 0; +#pragma unroll + for (int offset = neu_padded; offset < warp_size; offset += neu_padded) { + const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size); + if (threadIdx.x >= offset) { + it_compact_add_lower += tmp; + } + } + + if (iex_used != -1) { + store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used); + } + + // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads: + it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size); + } + } + nex_prev = warp_reduce_sum(nex_prev); + + for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) { + const mmq_ids_helper_store store_it = store[itc]; + const int it = store_it.it(); + const int iex_used = store_it.iex_used(); + ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y; + ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; + } + + if (threadIdx.x != 0) { + return; + } + + expert_bounds[expert] = nex_prev; + + if (expert < gridDim.x - 1) { + return; + } + + expert_bounds[gridDim.x] = nex_prev + it_compact; +} + +template +static void launch_mmq_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store"); + GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store"); + + const int id = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper, smpbo); + + const dim3 num_blocks(n_experts, 1, 1); + const dim3 block_size(warp_size, 1, 1); + const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store); + GGML_ASSERT(nbytes_shared <= smpbo); + mmq_ids_helper<<>> + (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); +} + static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { switch (args.type_x) { case GGML_TYPE_Q4_0: @@ -137,7 +271,7 @@ void ggml_cuda_mul_mat_q( ne00, ne01, ne1, s01, ne11, s1, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, - use_stream_k}; + use_stream_k, ne1}; ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); return; } @@ -148,53 +282,49 @@ void ggml_cuda_mul_mat_q( const int64_t n_expert_used = ids->ne[0]; const int64_t ne_get_rows = ne12 * n_expert_used; + GGML_ASSERT(ne1 == n_expert_used); - std::vector ids_host(ggml_nbytes(ids)); - std::vector ids_src1_host; - ids_src1_host.reserve(ne_get_rows); - std::vector ids_dst_host; - ids_dst_host.reserve(ne_get_rows); - std::vector tokens_per_expert_host(ne02); - std::vector expert_bounds_host(ne02 + 1); - ggml_cuda_pool_alloc ids_buf_dev(ctx.pool()); - - CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); - CUDA_CHECK(cudaStreamSynchronize(stream)); - - for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices - for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens - for (int64_t iex = 0; iex < n_expert_used; ++iex) { - const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]); - assert(expert_to_use >= 0 && expert_to_use < ne02); - if (expert_to_use == i02) { - ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11); - ids_dst_host.push_back(i12*ne1 + iex); - tokens_per_expert_host[i02]++; - break; - } - } - } - } + ggml_cuda_pool_alloc ids_src1(ctx.pool(), ne_get_rows); + ggml_cuda_pool_alloc ids_dst(ctx.pool(), ne_get_rows); + ggml_cuda_pool_alloc expert_bounds(ctx.pool(), ne02 + 1); - int32_t cumsum = 0; - for (int64_t i = 0; i < ne02; ++i) { - expert_bounds_host[i] = cumsum; - cumsum += tokens_per_expert_host[i]; + { + GGML_ASSERT(ids->nb[0] == ggml_element_size(ids)); + const int si1 = ids->nb[1] / ggml_element_size(ids); + const int sis1 = nb12 / nb11; + + switch (n_expert_used) { + case 2: + launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 4: + launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 6: + launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 8: + launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 16: + launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 32: + launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + default: + launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + } + CUDA_CHECK(cudaGetLastError()); } - expert_bounds_host[ne02] = cumsum; - - std::vector ids_buf_host; - ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size()); - ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end()); - ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end()); - ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end()); - ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device. - CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream)); - CUDA_CHECK(cudaStreamSynchronize(stream)); - - const int32_t * ids_src1_dev = ids_buf_dev.ptr; - const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size(); - const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size(); const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq); @@ -208,7 +338,7 @@ void ggml_cuda_mul_mat_q( const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[2] / ts_src1; - quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type, + quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); CUDA_CHECK(cudaGetLastError()); } @@ -218,11 +348,11 @@ void ggml_cuda_mul_mat_q( // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid. const mmq_args args = { - src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d, + src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d, ne00, ne01, ne_get_rows, s01, ne_get_rows, s1, ne02, ne02, s02, s12, s2, ne03, ne13, s03, s13, s3, - use_stream_k}; + use_stream_k, ne12}; ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); } @@ -262,14 +392,11 @@ void ggml_cuda_op_mul_mat_q( ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, - use_stream_k}; + use_stream_k, src1_ncols}; ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_ddf_i); - GGML_UNUSED(src1_padded_row_size); + GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_padded_row_size); } bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 96129bd831f..c9a07e82fed 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -1255,7 +1255,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( } } #else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); + GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; #endif // AMD_MFMA_AVAILABLE } @@ -1572,7 +1572,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } } #else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); + GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; #endif // AMD_MFMA_AVAILABLE } @@ -2301,7 +2301,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } } #else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); + GGML_UNUSED_VARS(x, y, sum, k00); NO_DEVICE_CODE; #endif // AMD_MFMA_AVAILABLE } @@ -2855,12 +2855,14 @@ static __device__ __forceinline__ void mmq_write_back_mma( #else typedef tile<16, 8, int> tile_C; constexpr int rows_per_warp = 2 * granularity; -#endif +#endif // defined(AMD_MFMA_AVAILABLE) constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); #if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); +#else + GGML_UNUSED(nwarps); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) #pragma unroll @@ -3136,7 +3138,8 @@ static __global__ void mul_mat_q( const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup, const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ncols_max) { // Skip unused template specializations for faster compilation: if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { @@ -3150,7 +3153,7 @@ static __global__ void mul_mat_q( constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); - const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x + const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y // Initialize the ids for writing back data with just the index. @@ -3374,7 +3377,8 @@ template static __global__ void mul_mat_q_stream_k_fixup( const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, - const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) { + const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst, + const int ncols_max) { constexpr int mmq_y = get_mmq_y_device(); constexpr int qk = ggml_cuda_type_traits::qk; constexpr int blocks_per_iter = MMQ_ITER_K / qk; @@ -3385,7 +3389,7 @@ static __global__ void mul_mat_q_stream_k_fixup( float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; - const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; + const int ntx = (ncols_max + mmq_x - 1) / mmq_x; const int nty = (nrows_x + mmq_y - 1) / mmq_y; const int bidx0 = blockIdx.x; @@ -3526,7 +3530,7 @@ struct mmq_args { int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst; int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst; int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst; - bool use_stream_k; + bool use_stream_k; int64_t ncols_max; }; template @@ -3556,7 +3560,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; - const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x; + const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x; const int ntzw = args.nchannels_y * args.nsamples_y; const dim3 block_nums_xy_tiling(nty, ntx, ntzw); @@ -3572,14 +3576,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); } else { constexpr bool need_check = true; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); } return; } @@ -3599,7 +3605,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); if (!fixup_needed) { return; @@ -3607,14 +3614,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a mul_mat_q_stream_k_fixup<<>> (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, - args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst); + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, + args.ncols_max); } else { constexpr bool need_check = true; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); if (!fixup_needed) { return; @@ -3622,7 +3631,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a mul_mat_q_stream_k_fixup<<>> (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, - args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst); + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, + args.ncols_max); } } @@ -3647,7 +3657,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda continue; } - const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x; + const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x; if (ntiles_x < ntiles_x_best) { mmq_x_best = mmq_x; diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu deleted file mode 100644 index e14c93516bd..00000000000 --- a/ggml/src/ggml-cuda/mmv.cu +++ /dev/null @@ -1,506 +0,0 @@ -#include "ggml.h" -#include "common.cuh" -#include "mmv.cuh" - -template -static __global__ void mul_mat_vec( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, - const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { - const int row = blockIdx.x; - const int channel_dst = blockIdx.y; - const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; - const int channel_y = ids ? channel_dst % nchannels_y : channel_dst; - const int sample_dst = blockIdx.z; - const int sample_x = sample_dst / sample_ratio; - const int sample_y = sample_dst; - const int tid = threadIdx.x; - - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - - x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; - y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; - dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; - - const float2 * y2 = (const float2 *) y; - - extern __shared__ char data_mmv[]; - float * buf_iw = (float *) data_mmv; - - if (block_size > warp_size) { - if (tid < warp_size) { - buf_iw[tid] = 0.0f; - } - __syncthreads(); - } - - float sumf[ncols_dst] = {0.0f}; - - if constexpr (std::is_same::value) { - const float2 * x2 = (const float2 *) x; - - for (int col2 = tid; col2 < ncols2; col2 += block_size) { - const float2 tmpx = x2[col2]; - -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumf[j] += tmpx.x*tmpy.x; - sumf[j] += tmpx.y*tmpy.y; - } - } - } else if constexpr (std::is_same::value) { - const half2 * x2 = (const half2 *) x; - - if (std::is_same::value) { - for (int col2 = tid; col2 < ncols2; col2 += block_size) { - const float2 tmpx = __half22float2(x2[col2]); - -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumf[j] += tmpx.x * tmpy.x; - sumf[j] += tmpx.y * tmpy.y; - } - } - } else { -#ifdef FP16_AVAILABLE - half2 sumh2[ncols_dst] = {{0.0f, 0.0f}}; - - for (int col2 = tid; col2 < ncols2; col2 += block_size) { - const half2 tmpx = x2[col2]; - -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y); - } - } - -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]); - } -#else - NO_DEVICE_CODE; -#endif // FP16_AVAILABLE - } - } else if constexpr (std::is_same::value) { - const int * x2 = (const int *) x; - for (int col2 = tid; col2 < ncols2; col2 += block_size) { - const int tmpx = x2[col2]; -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - const float2 tmpy = y2[j*stride_col_y2 + col2]; - sumf[j] += float(reinterpret_cast(&tmpx)[0]) * tmpy.x; - sumf[j] += float(reinterpret_cast(&tmpx)[1]) * tmpy.y; - } - } - } else { - static_assert(std::is_same::value, "unsupported type"); - } - -#pragma unroll - for (int j = 0; j < ncols_dst; ++j) { - sumf[j] = warp_reduce_sum(sumf[j]); - - if (block_size > warp_size) { - buf_iw[tid/warp_size] = sumf[j]; - __syncthreads(); - if (tid < warp_size) { - sumf[j] = buf_iw[tid]; - sumf[j] = warp_reduce_sum(sumf[j]); - } - if (j < ncols_dst) { - __syncthreads(); - } - } - } - - if (tid >= ncols_dst) { - return; - } - - dst[tid*stride_col_dst + row] = sumf[tid]; -} - -template -static void launch_mul_mat_vec_cuda( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols, const int64_t nrows, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - GGML_ASSERT(ncols % 2 == 0); - GGML_ASSERT(stride_row % 2 == 0); - GGML_ASSERT(stride_col_y % 2 == 0); - GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); - GGML_ASSERT( nsamples_dst % nsamples_x == 0); - const int64_t channel_ratio = nchannels_dst / nchannels_x; - const int64_t sample_ratio = nsamples_dst / nsamples_x; - int device; - int warp_size; - - CUDA_CHECK(cudaGetDevice(&device)); - warp_size = ggml_cuda_info().devices[device].warp_size; - - int64_t block_size_best = warp_size; - int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size); - int64_t max_block_size = 256; - if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) { - max_block_size = 128; - } - for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) { - const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size); - if (niter < niter_best) { - niter_best = niter; - block_size_best = block_size; - } - } - - const int smem = warp_size*sizeof(float); - const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); - const dim3 block_dims(block_size_best, 1, 1); - switch (block_size_best) { - case 32: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 64: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 96: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 128: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 160: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 192: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 224: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 256: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - default: { - GGML_ABORT("fatal error"); - } break; - } -} - -template -static void mul_mat_vec_cuda_switch_ncols_dst( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - switch (ncols_dst) { - case 1: - launch_mul_mat_vec_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - case 2: - launch_mul_mat_vec_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - case 3: - launch_mul_mat_vec_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - case 4: - launch_mul_mat_vec_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - case 5: - launch_mul_mat_vec_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - case 6: - launch_mul_mat_vec_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - case 7: - launch_mul_mat_vec_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - case 8: - launch_mul_mat_vec_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - break; - default: - GGML_ABORT("fatal error"); - break; - } -} - -template -static void mul_mat_vec_cuda( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, - const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - enum ggml_prec prec, cudaStream_t stream) { - if constexpr(std::is_same::value) { - if (prec == GGML_PREC_DEFAULT) { - mul_mat_vec_cuda_switch_ncols_dst - (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - return; - } - } - mul_mat_vec_cuda_switch_ncols_dst - (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); -} - -void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { - GGML_ASSERT( src1->type == GGML_TYPE_F32); - GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - GGML_TENSOR_BINARY_OP_LOCALS; - - const size_t ts_src0 = ggml_type_size(src0->type); - const size_t ts_src1 = ggml_type_size(src1->type); - const size_t ts_dst = ggml_type_size(dst->type); - - GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. - GGML_ASSERT(ne13 == ne3); - - GGML_ASSERT( nb00 == ts_src0); - GGML_ASSERT( nb10 == ts_src1); - GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); - GGML_ASSERT( nb0 == ts_dst); - - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; - - const float * src1_d = (const float *) src1->data; - const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; - float * dst_d = (float *) dst->data; - - const int64_t s01 = src0->nb[1] / ts_src0; - const int64_t s11 = src1->nb[1] / ts_src1; - const int64_t s1 = dst->nb[1] / ts_dst; - const int64_t s02 = src0->nb[2] / ts_src0; - const int64_t s12 = src1->nb[2] / ts_src1; - const int64_t s2 = dst->nb[2] / ts_dst; - const int64_t s03 = src0->nb[3] / ts_src0; - const int64_t s13 = src1->nb[3] / ts_src1; - const int64_t s3 = dst->nb[3] / ts_dst; - - // For MUL_MAT_ID the memory layout is different than for MUL_MAT: - const int64_t ncols_dst = ids ? ne2 : ne1; - const int64_t nchannels_y = ids ? ne11 : ne12; - const int64_t nchannels_dst = ids ? ne1 : ne2; - const int64_t stride_channel_dst = ids ? s1 : s2; - const int64_t stride_channel_y = ids ? s11 : s12; - - GGML_ASSERT(!ids || ncols_dst == 1); - - switch (src0->type) { - case GGML_TYPE_F32: { - const float * src0_d = (const float *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, - ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); - } break; - case GGML_TYPE_F16: { - const half * src0_d = (const half *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, - ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); - } break; - case GGML_TYPE_BF16: { - const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, - ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); - } break; - default: - GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); - } -} - -void ggml_cuda_op_mul_mat_vec( - ggml_backend_cuda_context & ctx, - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, - const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, cudaStream_t stream) { - - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne10 = src1->ne[0]; - const int64_t ne0 = dst->ne[0]; - const int64_t row_diff = row_high - row_low; - - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; - - - // ggml_cuda_op provides single, contiguous matrices - const int64_t stride_row = ne00; - const int64_t stride_col_y = ne10; - const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer - const int64_t nchannels_x = 1; - const int64_t nchannels_y = 1; - const int64_t nchannels_dst = 1; - const int64_t stride_channel_x = 0; - const int64_t stride_channel_y = 0; - const int64_t stride_channel_dst = 0; - const int64_t nsamples_x = 1; - const int64_t nsamples_dst = 1; - const int64_t stride_sample_x = 0; - const int64_t stride_sample_y = 0; - const int64_t stride_sample_dst = 0; - - switch (src0->type) { - case GGML_TYPE_F32: { - const float * src0_d = (const float *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); - } break; - case GGML_TYPE_F16: { - const half * src0_d = (const half *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); - } break; - case GGML_TYPE_BF16: { - const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); - } break; - default: - GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); - } - - GGML_UNUSED(ctx); - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_ddq_i); - GGML_UNUSED(src1_ncols); - GGML_UNUSED(src1_padded_row_size); -} - -bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) { - if (src0_ne[0] % 2 != 0) { - return false; - } - switch (type) { - case GGML_TYPE_F32: - if (GGML_CUDA_CC_IS_NVIDIA(cc)) { - if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { - return ne11 <= 8; - } - if (cc >= GGML_CUDA_CC_TURING) { - return ne11 <= 4; - } - return ne11 <= 3; - } else if (GGML_CUDA_CC_IS_AMD(cc)) { - if (fp32_mma_hardware_available(cc)) { - return ne11 <= 3; - } - return ne11 <= 8; - } - return ne11 <= 8; - case GGML_TYPE_F16: - if (GGML_CUDA_CC_IS_NVIDIA(cc)) { - const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1); - if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { - return src0_small && ne11 <= 4; - } - if (fp16_mma_hardware_available(cc)) { - return src0_small && ne11 <= 3; - } - return ne11 <= 8; - } else if (GGML_CUDA_CC_IS_AMD(cc)) { - if (fp16_mma_hardware_available(cc)) { - if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { - return ne11 <= 5; - } - return ne11 <= 2; - } - return ne11 <= 8; - } - return ne11 <= 8; - case GGML_TYPE_BF16: - if (GGML_CUDA_CC_IS_NVIDIA(cc)) { - const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1); - if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { - return src0_small && ne11 <= 4; - } - if (bf16_mma_hardware_available(cc)) { - return src0_small && ne11 <= 3; - } - return ne11 <= 8; - } else if (GGML_CUDA_CC_IS_AMD(cc)) { - if (bf16_mma_hardware_available(cc)) { - return ne11 <= 3; - } - return ne11 <= 8; - } - return ne11 <= 8; - default: - return false; - } -} diff --git a/ggml/src/ggml-cuda/mmv.cuh b/ggml/src/ggml-cuda/mmv.cuh deleted file mode 100644 index 1330bcb6a88..00000000000 --- a/ggml/src/ggml-cuda/mmv.cuh +++ /dev/null @@ -1,11 +0,0 @@ -#include "common.cuh" - -void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); - -void ggml_cuda_op_mul_mat_vec( - ggml_backend_cuda_context & ctx, - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, - const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, cudaStream_t stream); - -bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11); diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 16100b68045..5b21ef05b3c 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -433,12 +433,7 @@ void ggml_cuda_op_mul_mat_vec_f( GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); } - GGML_UNUSED(ctx); - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_ddq_i); - GGML_UNUSED(src1_ncols); - GGML_UNUSED(src1_padded_row_size); + GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size); } bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) { diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 5c8e5c4a7ee..52de4e78d13 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -141,9 +141,10 @@ template __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst, - const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, + const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, + const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) { constexpr int qk = ggml_cuda_type_traits::qk; constexpr int qi = ggml_cuda_type_traits::qi; @@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q( constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. - const int channel_dst = blockIdx.y; - const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio; - const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst; - const int sample_dst = blockIdx.z; - const int sample_x = sample_dst / sample_ratio; - const int sample_y = sample_dst; + const uint32_t channel_dst = blockIdx.y; + const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + const uint32_t sample_dst = blockIdx.z; + const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); + const uint32_t sample_y = sample_dst; // partial sum for each thread float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; @@ -247,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst( GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); - const int channel_ratio = nchannels_dst / nchannels_x; - const int sample_ratio = nsamples_dst / nsamples_x; + const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0); + const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); + const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); const int device = ggml_cuda_get_device(); const int warp_size = ggml_cuda_info().devices[device].warp_size; @@ -256,86 +258,70 @@ static void mul_mat_vec_q_switch_ncols_dst( GGML_ASSERT(!ids || ncols_dst == 1); switch (ncols_dst) { - case 1: - { + case 1: { constexpr int c_ncols_dst = 1; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 2: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 2: { constexpr int c_ncols_dst = 2; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 3: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 3: { constexpr int c_ncols_dst = 3; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 4: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 4: { constexpr int c_ncols_dst = 4; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 5: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 5: { constexpr int c_ncols_dst = 5; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 6: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 6: { constexpr int c_ncols_dst = 6; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 7: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 7: { constexpr int c_ncols_dst = 7; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 8: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 8: { constexpr int c_ncols_dst = 8; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; default: GGML_ABORT("fatal error"); break; @@ -596,9 +582,5 @@ void ggml_cuda_op_mul_mat_vec_q( src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream); - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_ddf_i); - GGML_UNUSED(src1_ncols); - GGML_UNUSED(src1_padded_row_size); + GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size); } diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index bddcca51b7b..4f153c5718e 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -104,12 +104,30 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } } -template -static __global__ void rms_norm_f32( - const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, - const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0, - const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0, - const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) { +template +static __global__ void rms_norm_f32(const float * x, + float * dst, + const int ncols, + const int64_t stride_row, + const int64_t stride_channel, + const int64_t stride_sample, + const float eps, + const float * mul = nullptr, + const int64_t mul_stride_row = 0, + const int64_t mul_stride_channel = 0, + const int64_t mul_stride_sample = 0, + const uint3 mul_ncols_packed = make_uint3(0, 0, 0), + const uint3 mul_nrows_packed = make_uint3(0, 0, 0), + const uint3 mul_nchannels_packed = make_uint3(0, 0, 0), + const uint3 mul_nsamples_packed = make_uint3(0, 0, 0), + const float * add = nullptr, + const int64_t add_stride_row = 0, + const int64_t add_stride_channel = 0, + const int64_t add_stride_sample = 0, + const uint3 add_ncols_packed = make_uint3(0, 0, 0), + const uint3 add_nrows_packed = make_uint3(0, 0, 0), + const uint3 add_nchannels_packed = make_uint3(0, 0, 0), + const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) { const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -118,14 +136,23 @@ static __global__ void rms_norm_f32( const int sample = blockIdx.z; const int tid = threadIdx.x; + static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying"); + x += sample*stride_sample + channel*stride_channel + row*stride_row; dst += ((sample*nchannels + channel)*nrows + row)*ncols; if constexpr (do_multiply) { - const int mul_row = row % mul_nrows; - const int mul_channel = channel % mul_nchannels; - const int mul_sample = sample % mul_nsamples; - mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row; + const uint32_t mul_row = fastmodulo(row, mul_nrows_packed); + const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed); + const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed); + mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row; + } + + if constexpr (do_add) { + const int add_row = fastmodulo(row, add_nrows_packed); + const int add_channel = fastmodulo(channel, add_nchannels_packed); + const int add_sample = fastmodulo(sample, add_nsamples_packed); + add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row; } float tmp = 0.0f; // partial sum for thread in warp @@ -138,15 +165,18 @@ static __global__ void rms_norm_f32( // sum up partial sums tmp = warp_reduce_sum(tmp); if constexpr (block_size > WARP_SIZE) { - static_assert(block_size == 1024, "unexpected block_size"); + static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size"); __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } __syncthreads(); - tmp = s_sum[lane_id]; + tmp = 0.0f; + if (lane_id < (block_size / WARP_SIZE)) { + tmp = s_sum[lane_id]; + } tmp = warp_reduce_sum(tmp); } @@ -154,9 +184,13 @@ static __global__ void rms_norm_f32( const float scale = rsqrtf(mean + eps); for (int col = tid; col < ncols; col += block_size) { - if constexpr (do_multiply) { - const int mul_col = col % mul_ncols; - dst[col] = scale * x[col] * mul[mul_col]; + if constexpr (do_multiply && do_add) { + const int mul_col = fastmodulo(col, mul_ncols_packed); + const int add_col = fastmodulo(col, add_ncols_packed); + dst[col] = scale * x[col] * mul[mul_col] + add[add_col]; + } else if constexpr (do_multiply) { + const int mul_col = fastmodulo(col, mul_ncols_packed); + dst[col] = scale * x[col] * mul[mul_col]; } else { dst[col] = scale * x[col]; } @@ -323,31 +357,87 @@ static void rms_norm_f32_cuda( const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); rms_norm_f32<1024, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } -static void rms_norm_mul_f32_cuda( - const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, - const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, - const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample, - const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples, - const float eps, cudaStream_t stream) { +static void rms_norm_mul_f32_cuda(const float * x, + const float * mul, + const float * add, + float * dst, + const int ncols, + const int nrows, + const int nchannels, + const int nsamples, + const int64_t stride_row, + const int64_t stride_channel, + const int64_t stride_sample, + const int64_t mul_stride_row, + const int64_t mul_stride_channel, + const int64_t mul_stride_sample, + const uint32_t mul_ncols, + const uint32_t mul_nrows, + const uint32_t mul_nchannels, + const uint32_t mul_nsamples, + const int64_t add_stride_row, + const int64_t add_stride_channel, + const int64_t add_stride_sample, + const uint32_t add_ncols, + const uint32_t add_nrows, + const uint32_t add_nchannels, + const uint32_t add_nsamples, + const float eps, + cudaStream_t stream) { const dim3 blocks_num(nrows, nchannels, nsamples); if (mul == nullptr) { rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream); return; } - if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); + if (add == nullptr) { + const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); + const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows); + const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels); + const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); + if (ncols < 1024) { + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32<1024, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + } } else { - const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); + const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); + const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows); + const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels); + const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); + + const uint3 add_ncols_packed = init_fastdiv_values(add_ncols); + const uint3 add_nrows_packed = init_fastdiv_values(add_nrows); + const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels); + const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples); + if (ncols < 1024) { + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, true, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, + add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, + add_nchannels_packed, add_nsamples_packed); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32<1024, true, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, + add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, + add_nchannels_packed, add_nsamples_packed); + } } } @@ -491,7 +581,102 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * const int mul_nchannels = mul_src->ne[2]; const int mul_nsamples = mul_src->ne[3]; - rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream); + rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d, + ne00, ne01, ne02, ne03, + /*s00*/ s01, s02, s03, + /*mul_s00*/ mul_s01, mul_s02, mul_s03, + mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, + /*add_s00*/ 0, 0, 0, + 0, 0, 0, 0, + eps, stream); +} + +void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx, + ggml_tensor * dst, + ggml_tensor * mul_tensor, + ggml_tensor * add_tensor) { + const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0]; + float eps = 0.0f; + + memcpy(&eps, dst->op_params, sizeof(float)); + + const float * src0_d = (const float *) rms_norm_src->data; + const float * mul_d = nullptr; + const ggml_tensor * mul_src = nullptr; + + if (mul_tensor->src[0] == dst) { + mul_d = (float *) mul_tensor->src[1]->data; + mul_src = mul_tensor->src[1]; + } else if (mul_tensor->src[1] == dst) { + mul_d = (float *) mul_tensor->src[0]->data; + mul_src = mul_tensor->src[0]; + } else { + GGML_ASSERT(false); + } + + const float * add_d = nullptr; + const ggml_tensor * add_src = nullptr; + + if (add_tensor->src[0] == mul_tensor) { + add_d = (float *) add_tensor->src[1]->data; + add_src = add_tensor->src[1]; + } else if (add_tensor->src[1] == mul_tensor) { + add_d = (float *) add_tensor->src[0]->data; + add_src = add_tensor->src[0]; + } else { + GGML_ASSERT(false); + } + + float * dst_d = (float *) add_tensor->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32); + GGML_ASSERT(add_tensor->type == GGML_TYPE_F32); + GGML_ASSERT(eps >= 0.0f); + + const int64_t ne00 = rms_norm_src->ne[0]; + const int64_t ne01 = rms_norm_src->ne[1]; + const int64_t ne02 = rms_norm_src->ne[2]; + const int64_t ne03 = rms_norm_src->ne[3]; + + const size_t ts0 = ggml_type_size(rms_norm_src->type); + GGML_ASSERT(rms_norm_src->nb[0] == ts0); + const int64_t s01 = rms_norm_src->nb[1] / ts0; + const int64_t s02 = rms_norm_src->nb[2] / ts0; + const int64_t s03 = rms_norm_src->nb[3] / ts0; + + const size_t ts_mul = ggml_type_size(mul_src->type); + GGML_ASSERT(mul_src->nb[0] == ts_mul); + const int64_t mul_s01 = mul_src->nb[1] / ts_mul; + const int64_t mul_s02 = mul_src->nb[2] / ts_mul; + const int64_t mul_s03 = mul_src->nb[3] / ts_mul; + + const int mul_ncols = mul_src->ne[0]; + const int mul_nrows = mul_src->ne[1]; + const int mul_nchannels = mul_src->ne[2]; + const int mul_nsamples = mul_src->ne[3]; + + const size_t ts_add = ggml_type_size(add_src->type); + GGML_ASSERT(add_src->nb[0] == ts_add); + const int64_t add_s01 = add_src->nb[1] / ts_add; + const int64_t add_s02 = add_src->nb[2] / ts_add; + const int64_t add_s03 = add_src->nb[3] / ts_add; + + const int add_ncols = add_src->ne[0]; + const int add_nrows = add_src->ne[1]; + const int add_nchannels = add_src->ne[2]; + const int add_nsamples = add_src->ne[3]; + + rms_norm_mul_f32_cuda(src0_d, mul_d,add_d,dst_d, + ne00,ne01, ne02, ne03, + /*s00*/ s01, s02, s03, + /*mul_s00*/ mul_s01, mul_s02, mul_s03, + mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, + /*add_s00*/ add_s01, add_s02, add_s03, + add_ncols, add_nrows, add_nchannels, add_nsamples, + eps, stream); } void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index 7ea7bd4df3c..a74f6376720 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -8,6 +8,11 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor); +void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx, + ggml_tensor * dst, + ggml_tensor * mul_tensor, + ggml_tensor * add_tensor); + void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu index 77432b04689..29aef33c1a4 100644 --- a/ggml/src/ggml-cuda/pad.cu +++ b/ggml/src/ggml-cuda/pad.cu @@ -1,36 +1,50 @@ #include "pad.cuh" -static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) { - // blockIdx.z: idx of ne2*ne3, aka ne02*ne03 - // blockIdx.y: idx of ne1 - // blockIDx.x: idx of ne0 / BLOCK_SIZE - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { +static __global__ void pad_f32(const float * src, float * dst, + const int lp0, const int rp0, const int lp1, const int rp1, + const int lp2, const int rp2, const int lp3, const int rp3, + const int ne0, const int ne1, const int ne2, const int ne3) { + // blockIdx.z: i3*ne2+i2 + // blockIdx.y: i1 + // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE + // gridDim.y: ne1 + int i0 = threadIdx.x + blockIdx.x * blockDim.x; + int i1 = blockIdx.y; + int i2 = blockIdx.z % ne2; + int i3 = blockIdx.z / ne2; + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { return; } // operation - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) { - int offset_src = - nidx + - blockIdx.y * ne00 + - blockIdx.z * ne00 * ne01; - dst[offset_dst] = x[offset_src]; + const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; + if ((i0 >= lp0 && i0 < ne0 - rp0) && + (i1 >= lp1 && i1 < ne1 - rp1) && + (i2 >= lp2 && i2 < ne2 - rp2) && + (i3 >= lp3 && i3 < ne3 - rp3)) { + const int64_t i00 = i0 - lp0; + const int64_t i01 = i1 - lp1; + const int64_t i02 = i2 - lp2; + const int64_t i03 = i3 - lp3; + const int64_t ne02 = ne2 - lp2 - rp2; + const int64_t ne01 = ne1 - lp1 - rp1; + const int64_t ne00 = ne0 - lp0 - rp0; + + const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00; + + dst[dst_idx] = src[src_idx]; } else { - dst[offset_dst] = 0.0f; + dst[dst_idx] = 0.0f; } } -static void pad_f32_cuda(const float * x, float * dst, - const int ne00, const int ne01, const int ne02, const int ne03, +static void pad_f32_cuda(const float * src, float * dst, + const int lp0, const int rp0, const int lp1, const int rp1, + const int lp2, const int rp2, const int lp3, const int rp3, const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) { int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; dim3 gridDim(num_blocks, ne1, ne2*ne3); - pad_f32<<>>(x, dst, ne0, ne00, ne01, ne02, ne03); + pad_f32<<>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3); } void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -41,9 +55,18 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int32_t lp0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t rp0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t lp1 = ((const int32_t*)(dst->op_params))[2]; + const int32_t rp1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t lp2 = ((const int32_t*)(dst->op_params))[4]; + const int32_t rp2 = ((const int32_t*)(dst->op_params))[5]; + const int32_t lp3 = ((const int32_t*)(dst->op_params))[6]; + const int32_t rp3 = ((const int32_t*)(dst->op_params))[7]; pad_f32_cuda(src0_d, dst_d, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); + lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); } diff --git a/ggml/src/ggml-cuda/pad_reflect_1d.cu b/ggml/src/ggml-cuda/pad_reflect_1d.cu new file mode 100644 index 00000000000..0478889da13 --- /dev/null +++ b/ggml/src/ggml-cuda/pad_reflect_1d.cu @@ -0,0 +1,89 @@ +#include "pad_reflect_1d.cuh" + +static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void + pad_reflect_1d_kernel_f32( + const void * __restrict__ src0, + void * __restrict__ dst, + const int64_t ne0, + const int64_t ne00, + const uint3 ne01, + const int64_t ne02, + const int64_t ne03, + const int64_t nb00, + const int64_t nb01, + const int64_t nb02, + const int64_t nb03, + const int64_t nb0, + const int64_t nb1, + const int64_t nb2, + const int64_t nb3, + const int p0, + const int p1) { + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + + const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01); + const int64_t tile1 = div_mod_packed.y; // i1 + const int64_t tile0 = div_mod_packed.x; // nth i0 tile + const int64_t i1 = tile1; + const int64_t i0 = threadIdx.x + tile0 * blockDim.x; + + // ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh) + if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) { + return; + } + + const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01; + char * dst_ptr = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1; + + const int64_t rel_i0 = i0 - p0; // relative i0 in src0 + int64_t src_idx; + + if (rel_i0 < 0) { + // Left padding - reflect + src_idx = -rel_i0; + } else if (rel_i0 < ne00) { + // Middle - copy + src_idx = rel_i0; + } else { + // Right padding - reflect + src_idx = 2 * ne00 - 2 - rel_i0; + } + const float value = *(const float *) (src0_ptr + src_idx * nb00); + *(float *) (dst_ptr + i0 * nb0) = value; +} + +void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int32_t * opts = (const int32_t *) dst->op_params; + const int p0 = opts[0]; + const int p1 = opts[1]; + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const uint3 ne01_packed = init_fastdiv_values(ne01); + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne0 = dst->ne[0]; + + // sanity: padded length matches + GGML_ASSERT(ne0 == ne00 + p0 + p1); + + constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x) + const int64_t tiles0 = (ne0 + bx - 1) / bx; // number of tiles along i0 + // grid.x covers i1 and all tiles of i0: [ne01 * tiles0] + // grid.y covers i2: [ne02] + // grid.z covers i3: [ne03] + const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03); + const dim3 block_dims((unsigned) bx, 1, 1); + + pad_reflect_1d_kernel_f32<<>>( + src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1); +} diff --git a/ggml/src/ggml-cuda/pad_reflect_1d.cuh b/ggml/src/ggml-cuda/pad_reflect_1d.cuh new file mode 100644 index 00000000000..15f2ed1737b --- /dev/null +++ b/ggml/src/ggml-cuda/pad_reflect_1d.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256 + +void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index a0b03a740d7..5117f9ffc0f 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -1,26 +1,27 @@ #include "quantize.cuh" #include +__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1) static __global__ void quantize_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, - const int64_t ne0, const int ne1, const int ne2) { + const int64_t ne0, const uint32_t ne1, const uint3 ne2) { const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= ne0) { return; } + const int64_t i3 = fastdiv(blockIdx.z, ne2); + const int64_t i2 = blockIdx.z - i3*ne2.z; const int64_t i1 = blockIdx.y; - const int64_t i2 = blockIdx.z % ne2; - const int64_t i3 = blockIdx.z / ne2; const int64_t & i00 = i0; const int64_t & i01 = i1; const int64_t & i02 = i2; const int64_t & i03 = i3; - const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0; + const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0; block_q8_1 * y = (block_q8_1 *) vy; @@ -31,10 +32,10 @@ static __global__ void quantize_q8_1( float amax = fabsf(xi); float sum = xi; - amax = warp_reduce_max(amax); - sum = warp_reduce_sum(sum); + amax = warp_reduce_max(amax); + sum = warp_reduce_sum(sum); - const float d = amax / 127; + const float d = amax / 127.0f; const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); y[ib].qs[iqs] = q; @@ -43,8 +44,7 @@ static __global__ void quantize_q8_1( return; } - reinterpret_cast(y[ib].ds.x) = d; - reinterpret_cast(y[ib].ds.y) = sum; + y[ib].ds = make_half2(d, sum); } template @@ -152,10 +152,12 @@ void quantize_row_q8_1_cuda( GGML_ASSERT(!ids); GGML_ASSERT(ne0 % QK8_1 == 0); + const uint3 ne2_fastdiv = init_fastdiv_values(ne2); + const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + quantize_q8_1<<>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); GGML_UNUSED(type_src0); } diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh index 6bee204136b..6bcae9e52fb 100644 --- a/ggml/src/ggml-cuda/reduce_rows.cuh +++ b/ggml/src/ggml-cuda/reduce_rows.cuh @@ -39,7 +39,7 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r } __syncthreads(); sum = 0.0f; - if (lane_id < (blockDim.x / WARP_SIZE)) { + if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { sum = s_sum[lane_id]; } sum = warp_reduce_sum(sum); diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu index 2ee9e588992..0ddeff6a175 100644 --- a/ggml/src/ggml-cuda/scale.cu +++ b/ggml/src/ggml-cuda/scale.cu @@ -1,18 +1,19 @@ #include "scale.cuh" -static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; +#define MAX_GRIDDIM_X 0x7FFFFFFF - if (i >= k) { - return; - } +static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) { + int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x; + int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x; - dst[i] = scale * x[i] + bias; + for (int64_t i = tid; i < nelements; i += stride) { + dst[i] = scale * x[i] + bias; + } } -static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<>>(x, dst, scale, bias, k); +static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) { + const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; + scale_f32<<>>(x, dst, scale, bias, nelements); } void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index dc9a7d58d05..6b424381df5 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -129,7 +129,7 @@ __global__ void __launch_bounds__(d_state, 1) const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float); const int seq_idx = blockIdx.y; - const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float); + const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float)); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 3428113dc8f..da2d7b7c3b3 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -24,7 +24,7 @@ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", - "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS" + "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4" ] SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. @@ -34,6 +34,13 @@ DECL_MMQ_CASE({type}); """ +SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE({type}); +""" + def get_short_name(long_quant_name): return long_quant_name.replace("GGML_TYPE_", "").lower() @@ -76,3 +83,7 @@ def get_head_sizes(type_k, type_v): for type in TYPES_MMQ: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: f.write(SOURCE_MMQ.format(type=type)) + +for type in range(1, 17): + with open(f"mmf-instance-ncols_{type}.cu", "w") as f: + f.write(SOURCE_MMF.format(type=type)) diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu new file mode 100644 index 00000000000..f594d5d51d2 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(1); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu new file mode 100644 index 00000000000..9cc67725421 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(10); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu new file mode 100644 index 00000000000..317f487d7a7 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(11); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu new file mode 100644 index 00000000000..dc0033227c0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(12); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu new file mode 100644 index 00000000000..07821017530 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(13); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu new file mode 100644 index 00000000000..a23ad6ae262 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(14); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu new file mode 100644 index 00000000000..0fe3f7821ee --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(15); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu new file mode 100644 index 00000000000..544086375e8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(16); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu new file mode 100644 index 00000000000..3b901797cfb --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(2); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu new file mode 100644 index 00000000000..56e940bba08 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(3); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu new file mode 100644 index 00000000000..a7665d49d0b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(4); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu new file mode 100644 index 00000000000..3a1dff2587a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(5); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu new file mode 100644 index 00000000000..400fb7c6631 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(6); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu new file mode 100644 index 00000000000..954a1c7e032 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(7); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu new file mode 100644 index 00000000000..f1bd09c9458 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(8); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu new file mode 100644 index 00000000000..1255ac2af66 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(9); diff --git a/ggml/src/ggml-cuda/tsembd.cu b/ggml/src/ggml-cuda/tsembd.cu index 153ddbcda92..b91a26fc80e 100644 --- a/ggml/src/ggml-cuda/tsembd.cu +++ b/ggml/src/ggml-cuda/tsembd.cu @@ -7,11 +7,11 @@ static __global__ void timestep_embedding_f32(const float * timesteps, float * d int j = threadIdx.x + blockIdx.x * blockDim.x; float * embed_data = (float *)((char *)dst + i*nb1); - if (dim % 2 != 0 && j == ((dim + 1) / 2)) { - embed_data[dim] = 0.f; + int half = dim / 2; + if (dim % 2 != 0 && j == half) { + embed_data[2 * half] = 0.f; } - int half = dim / 2; if (j >= half) { return; } diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index d8f9aa5ba62..6baab1176ff 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -28,7 +28,58 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32 return ((const int *) x)[i32]; // assume at least 4 byte alignment } +// q4 contains 8 indices with 4 bit each. +// This function selects those bytes from table that are at those indices and returns them as int2. +// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4. static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) { +#if defined(GGML_USE_HIP) + // Load the 16-byte table into four 32-bit unsigned integers. + const uint32_t *values = (const uint32_t *)table; + + const uint32_t q_even = q4; + const uint32_t q_odd = (q4 >> 4); + + // Perform lookups in the lower half of the table (indices 0-7). + uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707); + uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707); + + // Perform lookups in the upper half of the table (indices 8-15). + uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707); + uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707); + + // Select between the low and high results based on the MSB of each index nibble. + uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1); + uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even); + uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1); + uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd); + + return make_int2(res_x, res_y); +#elif !defined(GGML_USE_MUSA) + // CUDA does not have an instruction for selecting bytes with 4 bit indices. + // However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead. + const uint32_t * table32 = (const uint32_t *) table; + + // __byte_perm selects bytes based on the lower 16 bits in its third argument. + // Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift. + // To handle the fourth bit, first call _byte_perm both for the low and the high 64 bit of table, using the low 3 bits. + // Then, call __byte_perm again to select from the low and high bytes based on the fourth bit. + uint32_t tmp[2]; + const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1)); +#pragma unroll + for (uint32_t i = 0; i < 2; ++i) { + const uint32_t shift = 16 * i; + + const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift); + const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift); + tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift); + } + + // tmp contains the bytes from tyble in the same order as the 4 bit indices in q4. + // However, for the result we need ints with all even/odd 4 bit indices in q4. + // Therefore, 2 more calls to __byte_perm to put the bytes in the correct order. + return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531)); +#else + // Generic implementation. const int q0_32 = (q4 >> 0) & 0x0F0F0F0F; const int8_t * q0_8 = (const int8_t *) &q0_32; const char4 val0_8 = make_char4( @@ -40,6 +91,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]); return make_int2(*((const int *) &val0_8), *((const int *) &val1_8)); +#endif } // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called @@ -87,7 +139,7 @@ template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); } -#ifdef GGML_CUDA_F16 +#ifdef FAST_FP16_AVAILABLE const float2 tmp = __half22float2(__hmul2(dm4, ds8)); const float d4d8 = tmp.x; const float m4s8 = tmp.y; @@ -96,7 +148,7 @@ template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp const float2 ds8f = __half22float2(ds8); const float d4d8 = dm4f.x * ds8f.x; const float m4s8 = dm4f.y * ds8f.y; -#endif // GGML_CUDA_F16 +#endif // FAST_FP16_AVAILABLE // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)); @@ -158,7 +210,7 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values } -#ifdef GGML_CUDA_F16 +#ifdef FAST_FP16_AVAILABLE const float2 tmp = __half22float2(__hmul2(dm5, ds8)); const float d5d8 = tmp.x; const float m5s8 = tmp.y; @@ -167,7 +219,7 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp const float2 ds8f = __half22float2(ds8); const float d5d8 = dm5f.x * ds8f.x; const float m5s8 = dm5f.y * ds8f.y; -#endif // GGML_CUDA_F16 +#endif // FAST_FP16_AVAILABLE // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it return sumi*d5d8 + m5s8 / (QI5_1 / vdr); @@ -201,7 +253,7 @@ template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp sumi = ggml_cuda_dp4a(v[i], u[i], sumi); } -#ifdef GGML_CUDA_F16 +#ifdef FAST_FP16_AVAILABLE const float2 tmp = __half22float2(__hmul2(dm8, ds8)); const float d8d8 = tmp.x; const float m8s8 = tmp.y; @@ -210,7 +262,7 @@ template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp const float2 ds8f = __half22float2(ds8); const float d8d8 = dm8f.x * ds8f.x; const float m8s8 = dm8f.y * ds8f.y; -#endif // GGML_CUDA_F16 +#endif // FAST_FP16_AVAILABLE // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it return sumi*d8d8 + m8s8 / (QI8_1 / vdr); diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 6e9c67aca09..37386afcd40 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -22,7 +22,10 @@ #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) +#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +#define __all_sync(mask, var) __all(var) +#define __any_sync(mask, var) __any(var) #define cublasCreate hipblasCreate #define cublasDestroy hipblasDestroy #define cublasGemmEx hipblasGemmEx @@ -155,33 +158,41 @@ #define __CUDA_ARCH__ 1300 -#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) -#define GCN -#endif +#if defined(__gfx900__) || defined(__gfx906__) +#define GCN5 +#endif // defined(__gfx900__) || defined(__gfx906__) -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) -#define CDNA // For the entire family -#endif +#if defined(__gfx803__) +#define GCN4 +#endif // defined(__gfx803__) + +#if defined(GCN5) || defined(GCN4) +#define GCN +#endif // defined(GCN5) || defined(GCN4) #if defined(__gfx942__) #define CDNA3 -#endif +#endif // defined(__gfx942__) #if defined(__gfx90a__) #define CDNA2 -#endif +#endif // defined(__gfx90a__) #if defined(__gfx908__) #define CDNA1 -#endif +#endif // defined(__gfx908__) + +#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#define CDNA // For the entire family +#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1) #if defined(__GFX12__) #define RDNA4 -#endif +#endif // defined(__GFX12__) #if defined(__GFX11__) #define RDNA3 -#endif +#endif // defined(__GFX11__) #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) @@ -190,7 +201,11 @@ #if defined(__gfx1010__) || defined(__gfx1012__) #define RDNA1 -#endif +#endif // defined(__gfx1010__) || defined(__gfx1012__) + +#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1) +#define RDNA // For the entire family +#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1) #ifndef __has_builtin #define __has_builtin(x) 0 diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt index 0ca8a3c55ec..63418fe1430 100644 --- a/ggml/src/ggml-metal/CMakeLists.txt +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -5,7 +5,12 @@ find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) message(STATUS "Metal framework found") ggml_add_backend_library(ggml-metal - ggml-metal.m + ggml-metal.cpp + ggml-metal-device.m + ggml-metal-device.cpp + ggml-metal-common.cpp + ggml-metal-context.m + ggml-metal-ops.cpp ) target_link_libraries(ggml-metal PRIVATE @@ -18,10 +23,6 @@ if (GGML_METAL_NDEBUG) add_compile_definitions(GGML_METAL_NDEBUG) endif() -if (GGML_METAL_USE_BF16) - add_compile_definitions(GGML_METAL_USE_BF16) -endif() - # copy metal files to bin directory configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp new file mode 100644 index 00000000000..2ffec26619a --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -0,0 +1,460 @@ +#include "ggml-metal-common.h" + +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include + +// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb) +// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it) +struct ggml_mem_range { + uint64_t pb; // buffer id + + uint64_t p0; // begin + uint64_t p1; // end + + ggml_mem_range_type pt; +}; + +struct ggml_mem_ranges { + std::vector ranges; + + int debug = 0; +}; + +ggml_mem_ranges_t ggml_mem_ranges_init(int debug) { + auto * res = new ggml_mem_ranges; + + res->ranges.reserve(256); + res->debug = debug; + + return res; +} + +void ggml_mem_ranges_free(ggml_mem_ranges_t mrs) { + delete mrs; +} + +void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs) { + mrs->ranges.clear(); +} + +static bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, ggml_mem_range mr) { + mrs->ranges.push_back(mr); + + return true; +} + +static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggml_mem_range_type pt) { + // always use the base tensor + tensor = tensor->view_src ? tensor->view_src : tensor; + + GGML_ASSERT(!tensor->view_src); + + ggml_mem_range mr; + + if (tensor->buffer) { + // when the tensor is allocated, use the actual memory address range in the buffer + // + // take the actual allocated size with ggml_backend_buft_get_alloc_size() + // this can be larger than the tensor size if the buffer type allocates extra memory + // ref: https://github.com/ggml-org/llama.cpp/pull/15966 + mr = { + /*.pb =*/ (uint64_t) tensor->buffer, + /*.p0 =*/ (uint64_t) tensor->data, + /*.p1 =*/ (uint64_t) tensor->data + ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor), + /*.pt =*/ pt, + }; + } else { + // otherwise, the pointer address is used as an unique id of the memory ranges + // that the tensor will be using when it is allocated + mr = { + /*.pb =*/ (uint64_t) tensor, + /*.p0 =*/ 0, // + /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used + /*.pt =*/ pt, + }; + }; + + return mr; +} + +static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) { + return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC); +} + +static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) { + return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST); +} + +static bool ggml_mem_ranges_add_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { + GGML_ASSERT(tensor); + + ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor); + + if (mrs->debug > 2) { + GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1); + } + + return ggml_mem_ranges_add(mrs, mr); +} + +static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { + GGML_ASSERT(tensor); + + ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor); + + if (mrs->debug > 2) { + GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1); + } + + return ggml_mem_ranges_add(mrs, mr); +} + +bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (tensor->src[i]) { + ggml_mem_ranges_add_src(mrs, tensor->src[i]); + } + } + + return ggml_mem_ranges_add_dst(mrs, tensor); +} + +static bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, ggml_mem_range mr) { + for (size_t i = 0; i < mrs->ranges.size(); i++) { + const auto & cmp = mrs->ranges[i]; + + // two memory ranges cannot intersect if they are in different buffers + if (mr.pb != cmp.pb) { + continue; + } + + // intersecting source ranges are allowed + if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) { + continue; + } + + if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) { + if (mrs->debug > 2) { + GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n", + __func__, + mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst", + mr.pb, mr.p0, mr.p1, + cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst", + cmp.pb, cmp.p0, cmp.p1); + } + + return false; + } + } + + return true; +} + +static bool ggml_mem_ranges_check_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { + GGML_ASSERT(tensor); + + ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor); + + const bool res = ggml_mem_ranges_check(mrs, mr); + + return res; +} + +static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { + GGML_ASSERT(tensor); + + ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor); + + const bool res = ggml_mem_ranges_check(mrs, mr); + + return res; +} + +bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (tensor->src[i]) { + if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) { + return false; + } + } + } + + return ggml_mem_ranges_check_dst(mrs, tensor); +} + +// TODO: move to ggml.h? +static bool is_empty(ggml_op op) { + switch (op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_TRANSPOSE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + return true; + default: + return false; + } +} + +struct node_info { + ggml_tensor * node; + + std::vector fused; + + ggml_op op() const { + return node->op; + } + + const ggml_tensor * dst() const { + return fused.empty() ? node : fused.back(); + } + + bool is_empty() const { + return ::is_empty(node->op); + } + + void add_fused(ggml_tensor * t) { + fused.push_back(t); + } +}; + +static std::vector ggml_metal_graph_optimize_reorder(const std::vector & nodes) { + // helper to add node src and dst ranges + const auto & h_add = [](ggml_mem_ranges_t mrs, const node_info & node) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node.node->src[i]) { + if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) { + return false; + } + } + } + + // keep track of the sources of the fused nodes as well + for (const auto * fused : node.fused) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (fused->src[i]) { + if (!ggml_mem_ranges_add_src(mrs, fused->src[i])) { + return false; + } + } + } + } + + return ggml_mem_ranges_add_dst(mrs, node.dst()); + }; + + // helper to check if a node can run concurrently with the existing set of nodes + const auto & h_check = [](ggml_mem_ranges_t mrs, const node_info & node) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node.node->src[i]) { + if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) { + return false; + } + } + } + + for (const auto * fused : node.fused) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (fused->src[i]) { + if (!ggml_mem_ranges_check_src(mrs, fused->src[i])) { + return false; + } + } + } + } + + return ggml_mem_ranges_check_dst(mrs, node.dst()); + }; + + // perform reorders only across these types of ops + // can be expanded when needed + const auto & h_safe = [](ggml_op op) { + switch (op) { + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ROPE: + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MUL: + case GGML_OP_ADD: + case GGML_OP_DIV: + case GGML_OP_GLU: + case GGML_OP_SCALE: + case GGML_OP_GET_ROWS: + case GGML_OP_CPY: + case GGML_OP_SET_ROWS: + return true; + default: + return is_empty(op); + } + }; + + const int n = nodes.size(); + + std::vector res; + res.reserve(n); + + std::vector used(n, false); + + // the memory ranges for the set of currently concurrent nodes + ggml_mem_ranges_t mrs0 = ggml_mem_ranges_init(0); + + // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder + ggml_mem_ranges_t mrs1 = ggml_mem_ranges_init(0); + + for (int i0 = 0; i0 < n; i0++) { + if (used[i0]) { + continue; + } + + const auto & node0 = nodes[i0]; + + // the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e reset mrs0) + // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0 + // + // note: we can always add empty nodes to the concurrent set as they don't read nor write anything + if (!node0.is_empty() && !h_check(mrs0, node0)) { + // this will hold the set of memory ranges from the nodes that haven't been processed yet + // if a node is not concurrent with this set, we cannot reorder it + ggml_mem_ranges_reset(mrs1); + + // initialize it with the current node + h_add(mrs1, node0); + + // that many nodes forward to search for a concurrent node + constexpr int N_FORWARD = 8; + + for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) { + if (used[i1]) { + continue; + } + + const auto & node1 = nodes[i1]; + + // disallow reordering of certain ops + if (!h_safe(node1.op())) { + break; + } + + const bool is_empty = node1.is_empty(); + + // to reorder a node and add it to the concurrent set, it has to be: + // + empty or concurrent with all nodes in the existing concurrent set (mrs0) + // + concurrent with all nodes prior to it that haven't been processed yet (mrs1) + if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) { + // add the node to the existing concurrent set (i.e. reorder it for early execution) + h_add(mrs0, node1); + res.push_back(i1); + + // mark as used, so we skip re-processing it later + used[i1] = true; + } else { + // expand the set of nodes that haven't been processed yet + h_add(mrs1, node1); + } + } + + // finalize the concurrent set and begin a new one + ggml_mem_ranges_reset(mrs0); + } + + // expand the concurrent set with the current node + { + h_add(mrs0, node0); + res.push_back(i0); + } + } + + ggml_mem_ranges_free(mrs0); + ggml_mem_ranges_free(mrs1); + + return res; +} + +void ggml_graph_optimize(ggml_cgraph * gf) { + constexpr int MAX_FUSE = 16; + + const int n = gf->n_nodes; + + enum ggml_op ops[MAX_FUSE]; + + std::vector nodes; + nodes.reserve(gf->n_nodes); + + // fuse nodes: + // we don't want to make reorders that break fusing, so we first pack all fusable tensors + // and perform the reorder over the fused nodes. after the reorder is done, we unfuse + for (int i = 0; i < n; i++) { + node_info node = { + /*.node =*/ gf->nodes[i], + /*.fused =*/ {}, + }; + + // fuse only ops that start with these operations + // can be expanded when needed + if (node.op() == GGML_OP_ADD || + node.op() == GGML_OP_NORM || + node.op() == GGML_OP_RMS_NORM) { + ops[0] = node.op(); + + int f = i + 1; + while (f < n && f < i + MAX_FUSE) { + // conservatively allow fusing only these ops + // can be expanded when needed + if (gf->nodes[f]->op != GGML_OP_ADD && + gf->nodes[f]->op != GGML_OP_MUL && + gf->nodes[f]->op != GGML_OP_NORM && + gf->nodes[f]->op != GGML_OP_RMS_NORM) { + break; + } + ops[f - i] = gf->nodes[f]->op; + f++; + } + + f -= i; + for (; f > 1; f--) { + if (ggml_can_fuse(gf, i, ops, f)) { + break; + } + } + + // add the fused tensors into the node info so we can unfuse them later + for (int k = 1; k < f; k++) { + ++i; + + // the .dst() becomes the last fused tensor + node.add_fused(gf->nodes[i]); + } + } + + nodes.push_back(std::move(node)); + } + +#if 1 + // reorder to improve concurrency + const auto order = ggml_metal_graph_optimize_reorder(nodes); +#else + std::vector order(nodes.size()); + for (size_t i = 0; i < nodes.size(); i++) { + order[i] = i; + } +#endif + + // unfuse + { + int j = 0; + for (const auto i : order) { + const auto & node = nodes[i]; + + gf->nodes[j++] = node.node; + + for (auto * fused : node.fused) { + gf->nodes[j++] = fused; + } + } + } +} diff --git a/ggml/src/ggml-metal/ggml-metal-common.h b/ggml/src/ggml-metal/ggml-metal-common.h new file mode 100644 index 00000000000..3acbc6ae174 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal-common.h @@ -0,0 +1,52 @@ +// helper functions for ggml-metal that are too difficult to implement in Objective-C + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct ggml_tensor; +struct ggml_cgraph; + +enum ggml_mem_range_type { + MEM_RANGE_TYPE_SRC = 0, + MEM_RANGE_TYPE_DST = 1, +}; + +// a helper object that can be used for reordering operations to improve concurrency +// +// the fundamental idea is that a set of tasks (either ggml ops, or something else) can run concurrently if they +// don't write to a memory that is being read by another task or written to by another task in the set +// +// with this structure, we can add tasks to the set, setting memory constraints. we can also check if a new task +// can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the +// tasks already in the set) +// +typedef struct ggml_mem_ranges * ggml_mem_ranges_t; + +ggml_mem_ranges_t ggml_mem_ranges_init(int debug); +void ggml_mem_ranges_free(ggml_mem_ranges_t mrs); + +// remove all ranges from the set +void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs); + +// add src or dst ranges to track +bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor); + +// return false if: +// - new src range overlaps with any existing dst range +// - new dst range overlaps with any existing range (src or dst) +bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor); + +// reorder the nodes in the graph to improve concurrency, while respecting fusion +// +// note: this implementation is generic and not specific to metal +// if it proves to work well, we can start using it for other backends in the future +void ggml_graph_optimize(struct ggml_cgraph * gf); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-metal/ggml-metal-context.h b/ggml/src/ggml-metal/ggml-metal-context.h new file mode 100644 index 00000000000..ec2b686b733 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal-context.h @@ -0,0 +1,33 @@ +#pragma once + +#include "ggml-metal-device.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// +// backend context +// + +typedef struct ggml_metal * ggml_metal_t; + +ggml_metal_t ggml_metal_init(ggml_metal_device_t dev); +void ggml_metal_free(ggml_metal_t ctx); + +void ggml_metal_synchronize(ggml_metal_t ctx); + +void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); +void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + +enum ggml_status ggml_metal_graph_compute (ggml_metal_t ctx, struct ggml_cgraph * gf); +void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf); + +void ggml_metal_set_n_cb (ggml_metal_t ctx, int n_cb); +void ggml_metal_set_abort_callback (ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data); +bool ggml_metal_supports_family (ggml_metal_t ctx, int family); +void ggml_metal_capture_next_compute(ggml_metal_t ctx); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m new file mode 100644 index 00000000000..af9ff214360 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -0,0 +1,575 @@ +#import "ggml-metal-context.h" + +#import "ggml-impl.h" +#import "ggml-backend-impl.h" + +#import "ggml-metal-impl.h" +#import "ggml-metal-common.h" +#import "ggml-metal-ops.h" + +#import + +#import + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +// max number of MTLCommandBuffer used to submit a graph for processing +#define GGML_METAL_MAX_COMMAND_BUFFERS 8 + +struct ggml_metal_command_buffer { + id obj; +}; + +struct ggml_metal { + id device; + id queue; // currently a pointer to the device queue, but might become separate queue [TAG_QUEUE_PER_BACKEND] + + ggml_metal_device_t dev; + ggml_metal_library_t lib; + + dispatch_queue_t d_queue; + + // additional, inference-time compiled pipelines + ggml_metal_pipelines_t pipelines_ext; + + bool use_bfloat; + bool use_fusion; + bool use_concurrency; + bool use_graph_optimize; + + int debug_graph; + int debug_fusion; + + // how many times a given op was fused + uint64_t fuse_cnt[GGML_OP_COUNT]; + + // capture state + bool capture_next_compute; + bool capture_started; + + id capture_scope; + + // command buffer state + int n_cb; // number of extra threads used to submit the command buffers + int n_nodes_0; // number of nodes submitted by the main thread + int n_nodes_1; // remaining number of nodes submitted by the n_cb threads + int n_nodes_per_cb; + + struct ggml_cgraph * gf; + + // the callback given to the thread pool + void (^encode_async)(size_t ith); + + // n_cb command buffers + 1 used by the main thread + struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; + + // extra command buffers for things like getting, setting and copying tensors + NSMutableArray * cmd_bufs_ext; + + // the last command buffer queued into the Metal queue with operations relevant to the current Metal backend + id cmd_buf_last; + + // abort ggml_metal_graph_compute if callback returns true + ggml_abort_callback abort_callback; + void * abort_callback_data; +}; + +ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { + GGML_LOG_INFO("%s: allocating\n", __func__); + +#if TARGET_OS_OSX && !GGML_METAL_NDEBUG + // Show all the Metal device instances in the system + NSArray * devices = MTLCopyAllDevices(); + for (id device in devices) { + GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]); + } + [devices release]; // since it was created by a *Copy* C method +#endif + + // init context + ggml_metal_t res = calloc(1, sizeof(struct ggml_metal)); + + res->device = ggml_metal_device_get_obj(dev); + + GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[res->device name] UTF8String]); + + // TODO: would it be better to have one queue for the backend and one queue for the device? + // the graph encoders and async ops would use the backend queue while the sync ops would use the device queue? + //res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND] + res->queue = ggml_metal_device_get_queue(dev); + if (res->queue == nil) { + GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); + return NULL; + } + + res->dev = dev; + res->lib = ggml_metal_device_get_library(dev); + if (res->lib == NULL) { + GGML_LOG_WARN("%s: the device does not have a precompiled Metal library - this is unexpected\n", __func__); + GGML_LOG_WARN("%s: will try to compile it on the fly\n", __func__); + + res->lib = ggml_metal_library_init(dev); + if (res->lib == NULL) { + GGML_LOG_ERROR("%s: error: failed to initialize the Metal library\n", __func__); + + free(res); + + return NULL; + } + } + + const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + + res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); + + res->use_bfloat = props_dev->has_bfloat; + res->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil; + res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil; + + { + const char * val = getenv("GGML_METAL_GRAPH_DEBUG"); + res->debug_graph = val ? atoi(val) : 0; + } + + { + const char * val = getenv("GGML_METAL_FUSION_DEBUG"); + res->debug_fusion = val ? atoi(val) : 0; + } + + res->use_graph_optimize = true; + + if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) { + res->use_graph_optimize = false; + } + + memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt)); + + GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, res->use_bfloat ? "true" : "false"); + GGML_LOG_INFO("%s: use fusion = %s\n", __func__, res->use_fusion ? "true" : "false"); + GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false"); + GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false"); + + res->capture_next_compute = false; + res->capture_started = false; + res->capture_scope = nil; + + res->gf = nil; + res->encode_async = nil; + for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { + res->cmd_bufs[i].obj = nil; + } + + res->cmd_bufs_ext = [[NSMutableArray alloc] init]; + + res->cmd_buf_last = nil; + + res->pipelines_ext = ggml_metal_pipelines_init(); + + return res; +} + +void ggml_metal_free(ggml_metal_t ctx) { + GGML_LOG_INFO("%s: deallocating\n", __func__); + + for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { + if (ctx->cmd_bufs[i].obj) { + [ctx->cmd_bufs[i].obj release]; + } + } + + for (int i = 0; i < (int) ctx->cmd_bufs_ext.count; ++i) { + if (ctx->cmd_bufs_ext[i]) { + [ctx->cmd_bufs_ext[i] release]; + } + } + + [ctx->cmd_bufs_ext removeAllObjects]; + [ctx->cmd_bufs_ext release]; + + if (ctx->pipelines_ext) { + ggml_metal_pipelines_free(ctx->pipelines_ext); + ctx->pipelines_ext = nil; + } + + if (ctx->debug_fusion > 0) { + GGML_LOG_DEBUG("%s: fusion stats:\n", __func__); + for (int i = 0; i < GGML_OP_COUNT; i++) { + if (ctx->fuse_cnt[i] == 0) { + continue; + } + + // note: cannot use ggml_log here + GGML_LOG_DEBUG("%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]); + } + } + + Block_release(ctx->encode_async); + + //[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND] + + dispatch_release(ctx->d_queue); + + free(ctx); +} + +void ggml_metal_synchronize(ggml_metal_t ctx) { + // wait for any backend operations to finish + if (ctx->cmd_buf_last) { + [ctx->cmd_buf_last waitUntilCompleted]; + ctx->cmd_buf_last = nil; + } + + // release any completed command buffers + if (ctx->cmd_bufs_ext.count > 0) { + for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) { + id cmd_buf = ctx->cmd_bufs_ext[i]; + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) i, (int) status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + GGML_ABORT("fatal error"); + } + + [cmd_buf release]; + } + + [ctx->cmd_bufs_ext removeAllObjects]; + } +} + +static struct ggml_metal_buffer_id ggml_metal_get_buffer_id(const struct ggml_tensor * t) { + if (!t) { + return (struct ggml_metal_buffer_id) { nil, 0 }; + } + + ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer; + + return ggml_metal_buffer_get_id(buffer->context, t); +} + +void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + @autoreleasepool { + // wrap the source data into a Metal buffer + id buf_src = [ctx->device newBufferWithBytes:data + length:size + options:MTLResourceStorageModeShared]; + + struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(tensor); + if (bid_dst.metal == nil) { + GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name); + } + + bid_dst.offs += offset; + + // queue the copy operation into the queue of the Metal context + // this will be queued at the end, after any currently ongoing GPU operations + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:buf_src + sourceOffset:0 + toBuffer:bid_dst.metal + destinationOffset:bid_dst.offs + size:size]; + + [encoder endEncoding]; + [cmd_buf commit]; + + // do not wait here for completion + //[cmd_buf waitUntilCompleted]; + + // instead, remember a reference to the command buffer and wait for it later if needed + [ctx->cmd_bufs_ext addObject:cmd_buf]; + ctx->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + } +} + +void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + @autoreleasepool { + id buf_dst = [ctx->device newBufferWithBytesNoCopy:data + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; + + struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(tensor); + if (bid_src.metal == nil) { + GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name); + } + + bid_src.offs += offset; + + // queue the copy operation into the queue of the Metal context + // this will be queued at the end, after any currently ongoing GPU operations + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:bid_src.metal + sourceOffset:bid_src.offs + toBuffer:buf_dst + destinationOffset:0 + size:size]; + + [encoder endEncoding]; + [cmd_buf commit]; + + // do not wait here for completion + //[cmd_buf waitUntilCompleted]; + + // instead, remember a reference to the command buffer and wait for it later if needed + [ctx->cmd_bufs_ext addObject:cmd_buf]; + ctx->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + } +} + +enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) { + // number of nodes encoded by the main thread (empirically determined) + const int n_main = 64; + + // number of threads in addition to the main thread + const int n_cb = ctx->n_cb; + + // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them + // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread + // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes + // each thread creates it's own command buffer and enqueues the ops in parallel + // + // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2 + + @autoreleasepool { + ctx->gf = gf; + + ctx->n_nodes_0 = MIN(n_main, gf->n_nodes); + ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0; + + ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; + + const bool use_capture = ctx->capture_next_compute; + if (use_capture) { + ctx->capture_next_compute = false; + + // make sure all previous computations have finished before starting the capture + if (ctx->cmd_buf_last) { + [ctx->cmd_buf_last waitUntilCompleted]; + ctx->cmd_buf_last = nil; + } + + if (!ctx->capture_started) { + // create capture scope + ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device]; + + MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; + descriptor.captureObject = ctx->capture_scope; + descriptor.destination = MTLCaptureDestinationGPUTraceDocument; + descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; + + NSError * error = nil; + if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { + GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); + } else { + [ctx->capture_scope beginScope]; + ctx->capture_started = true; + } + } + } + + // the main thread commits the first few commands immediately + // cmd_buf[n_cb] + { + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + [cmd_buf retain]; + + if (ctx->cmd_bufs[n_cb].obj) { + [ctx->cmd_bufs[n_cb].obj release]; + } + ctx->cmd_bufs[n_cb].obj = cmd_buf; + + [cmd_buf enqueue]; + + ctx->encode_async(n_cb); + } + + // remember the command buffer for the next iteration + ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj; + + // prepare the rest of the command buffers asynchronously (optional) + // cmd_buf[0.. n_cb) + for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + [cmd_buf retain]; + + if (ctx->cmd_bufs[cb_idx].obj) { + [ctx->cmd_bufs[cb_idx].obj release]; + } + ctx->cmd_bufs[cb_idx].obj = cmd_buf; + + // always enqueue the first two command buffers + // enqueue all of the command buffers if we don't need to abort + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [cmd_buf enqueue]; + + // update the pointer to the last queued command buffer + // this is needed to implement synchronize() + ctx->cmd_buf_last = cmd_buf; + } + } + + dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async); + + // for debugging: block until graph is computed + //[ctx->cmd_buf_last waitUntilCompleted]; + + // enter here only when capturing in order to wait for all computation to finish + // otherwise, we leave the graph to compute asynchronously + if (!use_capture && ctx->capture_started) { + // wait for completion and check status of each command buffer + // needed to detect if the device ran out-of-memory for example (#1881) + { + id cmd_buf = ctx->cmd_bufs[n_cb].obj; + [cmd_buf waitUntilCompleted]; + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + } + + for (int i = 0; i < n_cb; ++i) { + id cmd_buf = ctx->cmd_bufs[i].obj; + [cmd_buf waitUntilCompleted]; + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + + id next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil); + if (!next_buffer) { + continue; + } + + const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); + if (next_queued) { + continue; + } + + if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { + GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i); + return GGML_STATUS_ABORTED; + } + + [next_buffer commit]; + } + + [ctx->capture_scope endScope]; + [[MTLCaptureManager sharedCaptureManager] stopCapture]; + } + } + + return GGML_STATUS_SUCCESS; +} + +void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) { + //const int64_t t_start = ggml_time_us(); + + if (ctx->use_graph_optimize) { + ggml_graph_optimize(gf); + } + + //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0); +} + +void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) { + if (ctx->n_cb != n_cb) { + ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS); + + if (ctx->n_cb > 2) { + GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb); + } + } + + if (ctx->encode_async) { + Block_release(ctx->encode_async); + } + + ctx->encode_async = Block_copy(^(size_t iter) { + const int cb_idx = iter; + const int n_cb_l = ctx->n_cb; + + const int n_nodes_0 = ctx->n_nodes_0; + const int n_nodes_1 = ctx->n_nodes_1; + + const int n_nodes_per_cb = ctx->n_nodes_per_cb; + + int idx_start = 0; + int idx_end = n_nodes_0; + + if (cb_idx < n_cb_l) { + idx_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb); + idx_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1)); + } + + id cmd_buf = ctx->cmd_bufs[cb_idx].obj; + + ggml_metal_op_t ctx_op = ggml_metal_op_init( + ctx->dev, + cmd_buf, + ctx->gf, + idx_start, + idx_end, + ctx->use_fusion, + ctx->use_concurrency, + ctx->capture_next_compute, + ctx->debug_graph, + ctx->debug_fusion); + + for (int idx = idx_start; idx < idx_end;) { + const int res = ggml_metal_op_encode(ctx_op, idx); + if (res == 0) { + break; + } + + idx += res; + } + + ggml_metal_op_free(ctx_op); + + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [cmd_buf commit]; + } + }); +} + +void ggml_metal_set_abort_callback(ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data) { + ctx->abort_callback = abort_callback; + ctx->abort_callback_data = user_data; +} + +bool ggml_metal_supports_family(ggml_metal_t ctx, int family) { + GGML_ASSERT(ctx->device != nil); + + return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; +} + +void ggml_metal_capture_next_compute(ggml_metal_t ctx) { + ctx->capture_next_compute = true; +} diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp new file mode 100644 index 00000000000..6af1bd61680 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -0,0 +1,1378 @@ +#include "ggml-metal-device.h" + +#include "ggml-metal-impl.h" + +#include "ggml-impl.h" + +#include +#include +#include +#include + +struct ggml_metal_device_deleter { + void operator()(ggml_metal_device_t ctx) { + ggml_metal_device_free(ctx); + } +}; + +typedef std::unique_ptr ggml_metal_device_ptr; + +ggml_metal_device_t ggml_metal_device_get(void) { + static ggml_metal_device_ptr ctx { ggml_metal_device_init() }; + + return ctx.get(); +} + +struct ggml_metal_pipelines { + std::unordered_map data; +}; + +ggml_metal_pipelines_t ggml_metal_pipelines_init(void) { + ggml_metal_pipelines_t res = new ggml_metal_pipelines(); + + return res; +} + +void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) { + if (!ppls) { + return; + } + + for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) { + ggml_metal_pipeline_free(it->second); + } + + delete ppls; +} + +void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline) { + ppls->data[name] = pipeline; +} + +ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) { + if (ppls->data.find(name) == ppls->data.end()) { + return nullptr; + } + + return ppls->data[name]; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) { + char base[256]; + char name[256]; + + const char * op_str = "undefined"; + switch (op) { + case GGML_OP_ADD_ID: op_str = "add_id"; break; + case GGML_OP_CONCAT: op_str = "concat"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_%s", op_str); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type); + + const char * pool_str = "undefined"; + switch (op_pool) { + case GGML_OP_POOL_AVG: pool_str = "avg"; break; + case GGML_OP_POOL_MAX: pool_str = "max"; break; + default: GGML_ASSERT(false && "not implemented"); + }; + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tdst) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_set_rows_%s", ggml_type_name(tdst)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + + char base[256]; + char name[256]; + + const int64_t n = ggml_nelements(op); + + const char * op_str = "undefined"; + switch (op->op) { + case GGML_OP_SCALE: op_str = "scale"; break; + case GGML_OP_CLAMP: op_str = "clamp"; break; + case GGML_OP_SQR: op_str = "sqr"; break; + case GGML_OP_SQRT: op_str = "sqrt"; break; + case GGML_OP_SIN: op_str = "sin"; break; + case GGML_OP_COS: op_str = "cos"; break; + case GGML_OP_LOG: op_str = "log"; break; + case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_TANH: op_str = "tanh"; break; + case GGML_UNARY_OP_RELU: op_str = "relu"; break; + case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break; + case GGML_UNARY_OP_GELU: op_str = "gelu"; break; + case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break; + case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break; + case GGML_UNARY_OP_SILU: op_str = "silu"; break; + case GGML_UNARY_OP_ELU: op_str = "elu"; break; + case GGML_UNARY_OP_NEG: op_str = "neg"; break; + case GGML_UNARY_OP_ABS: op_str = "abs"; break; + case GGML_UNARY_OP_SGN: op_str = "sgn"; break; + case GGML_UNARY_OP_STEP: op_str = "step"; break; + case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break; + case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break; + case GGML_UNARY_OP_EXP: op_str = "exp"; break; + default: GGML_ABORT("fatal error"); + } break; + default: GGML_ABORT("fatal error"); + }; + + const char * suffix = ""; + if (n % 4 == 0) { + suffix = "_4"; + } + + snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); + + char base[256]; + char name[256]; + + const char * op_str = "undefined"; + switch (op->op) { + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: op_str = "reglu"; break; + case GGML_GLU_OP_GEGLU: op_str = "geglu"; break; + case GGML_GLU_OP_SWIGLU: op_str = "swiglu"; break; + case GGML_GLU_OP_SWIGLU_OAI: op_str = "swiglu_oai"; break; + case GGML_GLU_OP_GEGLU_ERF: op_str = "geglu_erf"; break; + case GGML_GLU_OP_GEGLU_QUICK: op_str = "geglu_quick"; break; + default: GGML_ABORT("fatal error"); + } break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + + char base[256]; + char name[256]; + + const char * op_str = "undefined"; + switch (op->op) { + case GGML_OP_SUM_ROWS: + op_str = "sum_rows"; break; + case GGML_OP_MEAN: + op_str = "mean"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32); + + char base[256]; + char name[256]; + + const char * suffix = ""; + + if (op->src[0]->ne[0] % 4 == 0) { + suffix = "_4"; + } + + const ggml_type tsrc1 = op->src[1] ? op->src[1]->type : GGML_TYPE_F32; + + snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + if (op->src[3]->ne[0] == 1) { + snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type)); + } else { + snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); + } + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + const int64_t C = op->ne[0]; + const int64_t H = op->src[0]->ne[1]; + + switch (op->op) { + case GGML_OP_RWKV_WKV6: + { + GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == 64); + + snprintf(base, 256, "kernel_rwkv_wkv6_%s", ggml_type_name(op->src[0]->type)); + } break; + case GGML_OP_RWKV_WKV7: + { + GGML_ASSERT(op->src[6]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == 64); + + snprintf(base, 256, "kernel_rwkv_wkv7_%s", ggml_type_name(op->src[0]->type)); + } break; + default: + GGML_ABORT("fatal error"); + } + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg); + snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 8192); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + + char base[256]; + char name[256]; + + int nsg = 0; // number of simdgroups + int nr0 = 0; // number of src0 rows per simdgroup + int nr1 = 1; // number of src1 rows per threadgroup + + size_t smem = 0; // shared memory + + const ggml_type tsrc0 = op->src[0]->type; + const ggml_type tsrc1 = op->src[1]->type; + + const char * suffix = ""; + + // use custom matrix x vector kernel + switch (tsrc0) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + { + if (ne00 == 4) { + nsg = 1; + nr0 = 32; + nr1 = 4; + suffix = "_c4"; + } else if (ne00 % 4 == 0) { + nsg = N_SG_F; + nr0 = N_R0_F; + nr1 = 1; + smem = 32*sizeof(float)*N_R0_F; + suffix = "_4"; + } else { + nsg = N_SG_F; + nr0 = N_R0_F; + nr1 = 1; + smem = 32*sizeof(float)*N_R0_F; + } + } break; + case GGML_TYPE_Q4_0: + { + nsg = N_SG_Q4_0; + nr0 = N_R0_Q4_0; + } break; + case GGML_TYPE_Q4_1: + { + nsg = N_SG_Q4_1; + nr0 = N_R0_Q4_1; + } break; + case GGML_TYPE_Q5_0: + { + nsg = N_SG_Q5_0; + nr0 = N_R0_Q5_0; + } break; + case GGML_TYPE_Q5_1: + { + nsg = N_SG_Q5_1; + nr0 = N_R0_Q5_1; + } break; + case GGML_TYPE_Q8_0: + { + nsg = N_SG_Q8_0; + nr0 = N_R0_Q8_0; + smem = 32*sizeof(float)*N_R0_Q8_0; + } break; + case GGML_TYPE_MXFP4: + { + nsg = N_SG_MXFP4; + nr0 = N_R0_MXFP4; + smem = 32*sizeof(float); + } break; + case GGML_TYPE_Q2_K: + { + nsg = N_SG_Q2_K; + nr0 = N_R0_Q2_K; + } break; + case GGML_TYPE_Q3_K: + { + nsg = N_SG_Q3_K; + nr0 = N_R0_Q3_K; + } break; + case GGML_TYPE_Q4_K: + { + nsg = N_SG_Q4_K; + nr0 = N_R0_Q4_K; + } break; + case GGML_TYPE_Q5_K: + { + nsg = N_SG_Q5_K; + nr0 = N_R0_Q5_K; + } break; + case GGML_TYPE_Q6_K: + { + nsg = N_SG_Q6_K; + nr0 = N_R0_Q6_K; + } break; + case GGML_TYPE_IQ2_XXS: + { + nsg = N_SG_IQ2_XXS; + nr0 = N_R0_IQ2_XXS; + smem = 256*8+128; + } break; + case GGML_TYPE_IQ2_XS: + { + nsg = N_SG_IQ2_XS; + nr0 = N_R0_IQ2_XS; + smem = 512*8+128; + } break; + case GGML_TYPE_IQ3_XXS: + { + nsg = N_SG_IQ3_XXS; + nr0 = N_R0_IQ3_XXS; + smem = 256*4+128; + } break; + case GGML_TYPE_IQ3_S: + { + nsg = N_SG_IQ3_S; + nr0 = N_R0_IQ3_S; + smem = 512*4; + } break; + case GGML_TYPE_IQ2_S: + { + nsg = N_SG_IQ2_S; + nr0 = N_R0_IQ2_S; + } break; + case GGML_TYPE_IQ1_S: + { + nsg = N_SG_IQ1_S; + nr0 = N_R0_IQ1_S; + } break; + case GGML_TYPE_IQ1_M: + { + nsg = N_SG_IQ1_M; + nr0 = N_R0_IQ1_M; + } break; + case GGML_TYPE_IQ4_NL: + { + nsg = N_SG_IQ4_NL; + nr0 = N_R0_IQ4_NL; + smem = 32*sizeof(float); + } break; + case GGML_TYPE_IQ4_XS: + { + nsg = N_SG_IQ4_XS; + nr0 = N_R0_IQ4_XS; + smem = 32*sizeof(float); + } break; + default: + { + GGML_LOG_ERROR("Asserting on type %d\n", (int) tsrc0); + GGML_ABORT("not implemented"); + } + }; + + snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); + snprintf(name, 256, "%s_nsg=%d", base, nsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + ggml_metal_pipeline_set_nr0 (res, nr0); + ggml_metal_pipeline_set_nr1 (res, nr1); + ggml_metal_pipeline_set_nsg (res, nsg); + ggml_metal_pipeline_set_smem(res, smem); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + const size_t smem = (size_t) ne02*ne20*sizeof(uint16_t); + + ggml_metal_pipeline_set_smem(res, smem); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 8192); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + + char base[256]; + char name[256]; + + int nsg = 0; // number of simdgroups + int nr0 = 0; // number of src0 rows per simdgroup + int nr1 = 1; // number of src1 rows per threadgroup + + size_t smem = 0; // shared memory + + const ggml_type tsrc0 = op->src[0]->type; + const ggml_type tsrc1 = op->src[1]->type; + + const char * suffix = ""; + + // use custom matrix x vector kernel + switch (tsrc0) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + { + if (ne00 % 4 == 0) { + nsg = N_SG_F; + nr0 = N_R0_F; + nr1 = 1; + smem = 32*sizeof(float)*N_R0_F; + suffix = "_4"; + } else { + nsg = N_SG_F; + nr0 = N_R0_F; + nr1 = 1; + smem = 32*sizeof(float)*N_R0_F; + } + } break; + case GGML_TYPE_Q4_0: + { + nsg = N_SG_Q4_0; + nr0 = N_R0_Q4_0; + } break; + case GGML_TYPE_Q4_1: + { + nsg = N_SG_Q4_1; + nr0 = N_R0_Q4_1; + } break; + case GGML_TYPE_Q5_0: + { + nsg = N_SG_Q5_0; + nr0 = N_R0_Q5_0; + } break; + case GGML_TYPE_Q5_1: + { + nsg = N_SG_Q5_1; + nr0 = N_R0_Q5_1; + } break; + case GGML_TYPE_Q8_0: + { + nsg = N_SG_Q8_0; + nr0 = N_R0_Q8_0; + smem = 32*sizeof(float)*N_R0_Q8_0; + } break; + case GGML_TYPE_MXFP4: + { + nsg = N_SG_MXFP4; + nr0 = N_R0_MXFP4; + smem = 32*sizeof(float); + } break; + case GGML_TYPE_Q2_K: + { + nsg = N_SG_Q2_K; + nr0 = N_R0_Q2_K; + } break; + case GGML_TYPE_Q3_K: + { + nsg = N_SG_Q3_K; + nr0 = N_R0_Q3_K; + } break; + case GGML_TYPE_Q4_K: + { + nsg = N_SG_Q4_K; + nr0 = N_R0_Q4_K; + } break; + case GGML_TYPE_Q5_K: + { + nsg = N_SG_Q5_K; + nr0 = N_R0_Q5_K; + } break; + case GGML_TYPE_Q6_K: + { + nsg = N_SG_Q6_K; + nr0 = N_R0_Q6_K; + } break; + case GGML_TYPE_IQ2_XXS: + { + nsg = N_SG_IQ2_XXS; + nr0 = N_R0_IQ2_XXS; + smem = 256*8+128; + } break; + case GGML_TYPE_IQ2_XS: + { + nsg = N_SG_IQ2_XS; + nr0 = N_R0_IQ2_XS; + smem = 512*8+128; + } break; + case GGML_TYPE_IQ3_XXS: + { + nsg = N_SG_IQ3_XXS; + nr0 = N_R0_IQ3_XXS; + smem = 256*4+128; + } break; + case GGML_TYPE_IQ3_S: + { + nsg = N_SG_IQ3_S; + nr0 = N_R0_IQ3_S; + smem = 512*4; + } break; + case GGML_TYPE_IQ2_S: + { + nsg = N_SG_IQ2_S; + nr0 = N_R0_IQ2_S; + } break; + case GGML_TYPE_IQ1_S: + { + nsg = N_SG_IQ1_S; + nr0 = N_R0_IQ1_S; + } break; + case GGML_TYPE_IQ1_M: + { + nsg = N_SG_IQ1_M; + nr0 = N_R0_IQ1_M; + } break; + case GGML_TYPE_IQ4_NL: + { + nsg = N_SG_IQ4_NL; + nr0 = N_R0_IQ4_NL; + smem = 32*sizeof(float); + } break; + case GGML_TYPE_IQ4_XS: + { + nsg = N_SG_IQ4_XS; + nr0 = N_R0_IQ4_XS; + smem = 32*sizeof(float); + } break; + default: + { + GGML_LOG_ERROR("Asserting on type %d\n", (int)op->src[2]->type); + GGML_ABORT("not implemented"); + } + }; + + snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); + snprintf(name, 256, "%s_nsg=%d", base, nsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + ggml_metal_pipeline_set_nr0 (res, nr0); + ggml_metal_pipeline_set_nr1 (res, nr1); + ggml_metal_pipeline_set_nsg (res, nsg); + ggml_metal_pipeline_set_smem(res, smem); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); + GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*(sizeof(float) + sizeof(int32_t))); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_ARGSORT); + + char base[256]; + char name[256]; + + ggml_sort_order order = (ggml_sort_order) op->op_params[0]; + + const char * order_str = "undefined"; + switch (order) { + case GGML_SORT_ORDER_ASC: order_str = "asc"; break; + case GGML_SORT_ORDER_DESC: order_str = "desc"; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( + ggml_metal_library_t lib, + const ggml_tensor * op, + bool has_mask, + bool has_sinks, + bool has_bias, + bool has_scap, + int32_t nsg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + char base[256]; + char name[256]; + + const int32_t dk = (int32_t) op->src[1]->ne[0]; + const int32_t dv = (int32_t) op->src[2]->ne[0]; + + const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; + const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; + + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", + "flash_attn_ext", + ggml_type_name(op->src[1]->type), + dk, + dv); + + snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d", + base, + has_mask, + has_sinks, + has_bias, + has_scap, + ns10, + ns20, + nsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0); + ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); + ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); + ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); + + ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); + ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); + ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( + ggml_metal_library_t lib, + const ggml_tensor * op, + bool has_mask, + bool has_sinks, + bool has_bias, + bool has_scap, + int32_t nsg, + int32_t nwg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + char base[256]; + char name[256]; + + const int32_t dk = (int32_t) op->src[1]->ne[0]; + const int32_t dv = (int32_t) op->src[2]->ne[0]; + + const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; + const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; + + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", + "flash_attn_ext_vec", + ggml_type_name(op->src[1]->type), + dk, + dv); + + snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", + base, + has_mask, + has_sinks, + has_bias, + has_scap, + ns10, + ns20, + nsg, nwg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0); + ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1); + ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); + ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); + + ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); + ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); + ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22); + ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( + ggml_metal_library_t lib, + const ggml_tensor * op, + int32_t dv, + int32_t nwg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce"); + snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0); + ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; + + GGML_UNUSED(op); +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin( + ggml_metal_library_t lib, + ggml_op op, + int32_t n_fuse, + bool row) { + char base[256]; + char name[256]; + + const char * op_str = "undefined"; + switch (op) { + case GGML_OP_ADD: op_str = "add"; break; + case GGML_OP_SUB: op_str = "sub"; break; + case GGML_OP_MUL: op_str = "mul"; break; + case GGML_OP_DIV: op_str = "div"; break; + default: GGML_ABORT("fatal error"); + }; + + if (row) { + snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse); + } else { + snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse); + } + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) { + assert(op->op == GGML_OP_RMS_NORM); + + GGML_ASSERT(op->src[0]->ne[0] % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + char base[256]; + char name[256]; + + switch (n_fuse) { + case 1: snprintf(base, 256, "kernel_rms_norm_f32"); break; + case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32"); break; + case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32"); break; + default: GGML_ABORT("fatal error"); + } + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_L2_NORM); + + GGML_ASSERT(op->src[0]->ne[0] % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_l2_norm_f32"); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_GROUP_NORM); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_group_norm_f32"); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) { + assert(op->op == GGML_OP_RMS_NORM); + + GGML_ASSERT(op->src[0]->ne[0] % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + char base[256]; + char name[256]; + + switch (n_fuse) { + case 1: snprintf(base, 256, "kernel_norm_f32"); break; + case 2: snprintf(base, 256, "kernel_norm_mul_f32"); break; + case 3: snprintf(base, 256, "kernel_norm_mul_add_f32"); break; + default: GGML_ABORT("fatal error"); + } + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_ROPE); + + char base[256]; + char name[256]; + + const int mode = ((const int32_t *) op->op_params)[2]; + + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_neox) { + snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type)); + } else if (is_mrope && !is_vision) { + GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token + snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type)); + } else if (is_vision) { + GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token + snprintf(base, 256, "kernel_rope_vision_%s", ggml_type_name(op->src[0]->type)); + } else { + snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type)); + } + + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_IM2COL); + + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_CONV_TRANSPOSE_1D); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_UPSCALE); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_PAD); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_PAD_REFLECT_1D); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_ARANGE); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_TIMESTEP_EMBEDDING); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h new file mode 100644 index 00000000000..74c2e1e60de --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -0,0 +1,227 @@ +#pragma once + +#include "ggml.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct ggml_metal_buffer_id { + void * metal; // id + size_t offs; +}; + +typedef struct ggml_metal_device * ggml_metal_device_t; + +// +// MTLFunctionConstantValues wrapper +// + +typedef struct ggml_metal_cv * ggml_metal_cv_t; + +ggml_metal_cv_t ggml_metal_cv_init(void); +void ggml_metal_cv_free(ggml_metal_cv_t cv); + +void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx); +void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx); +void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool value, int32_t idx); + +// +// MTLComputePipelineState wrapper +// + +typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t; + +ggml_metal_pipeline_t ggml_metal_pipeline_init(void); +void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline); + +void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg); +int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline); + +void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0); +int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline); + +void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1); +int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline); + +void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem); +size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline); + +int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline); + +// a collection of pipelines +typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t; + +ggml_metal_pipelines_t ggml_metal_pipelines_init(void); +void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls); + +void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline); +ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name); + +// +// MTLCommandBuffer wrapper +// + +typedef void * ggml_metal_cmd_buf_t; + +// +// MTLComputeCommandEncoder wrapper +// + +typedef struct ggml_metal_encoder * ggml_metal_encoder_t; + +ggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_cmd_buf_t cmd_buf_raw, bool concurrent); +void ggml_metal_encoder_free(ggml_metal_encoder_t encoder); + +void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name); +void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder); + +void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline); + +void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx); +void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx); + +void ggml_metal_encoder_set_threadgroup_memory_size(ggml_metal_encoder_t encoder, size_t size, int idx); + +void ggml_metal_encoder_dispatch_threadgroups(ggml_metal_encoder_t encoder, int tg0, int tg1, int tg2, int tptg0, int tptg1, int tptg2); + +void ggml_metal_encoder_memory_barrier(ggml_metal_encoder_t encoder); + +void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder); + +// +// MTLLibrary wrapper +// + +typedef struct ggml_metal_library * ggml_metal_library_t; + +ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev); +void ggml_metal_library_free(ggml_metal_library_t lib); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); +ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tdst); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + bool has_mask, + bool has_sinks, + bool has_bias, + bool has_scap, + int32_t nsg); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + bool has_mask, + bool has_sinks, + bool has_bias, + bool has_scap, + int32_t nsg, + int32_t nwg); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + int32_t dv, + int32_t nwg); + +// +// device +// + +struct ggml_metal_device_props { + char name[128]; + + size_t max_buffer_size; + size_t max_working_set_size; + size_t max_theadgroup_memory_size; + + bool has_simdgroup_reduction; + bool has_simdgroup_mm; + bool has_unified_memory; + bool has_bfloat; + bool use_residency_sets; + bool use_shared_buffers; + + bool supports_gpu_family_apple7; +}; + +ggml_metal_device_t ggml_metal_device_init(void); +void ggml_metal_device_free(ggml_metal_device_t dev); + +// return a singleton that is automatically destroyed when the program exits +ggml_metal_device_t ggml_metal_device_get(void); + +void * ggml_metal_device_get_obj (ggml_metal_device_t dev); // id +void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id + +ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev); + +void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total); +bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op); + +const struct ggml_metal_device_props * ggml_metal_device_get_props(ggml_metal_device_t dev); + +// +// device buffers +// + +typedef struct ggml_metal_buffer * ggml_metal_buffer_t; + +ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared); +ggml_metal_buffer_t ggml_metal_buffer_map (ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size); + +void ggml_metal_buffer_free (ggml_metal_buffer_t buf); +void * ggml_metal_buffer_get_base (ggml_metal_buffer_t buf); +bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf); + +void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); +void ggml_metal_buffer_set_tensor (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); +void ggml_metal_buffer_get_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); +void ggml_metal_buffer_clear (ggml_metal_buffer_t buf, uint8_t value); + +// finds the Metal buffer that contains the tensor data on the GPU device +// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the +// Metal buffer based on the host memory pointer +// +struct ggml_metal_buffer_id ggml_metal_buffer_get_id(ggml_metal_buffer_t buf, const struct ggml_tensor * t); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m new file mode 100644 index 00000000000..67f71ace2b4 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -0,0 +1,1303 @@ +#import "ggml-metal-device.h" + +#import "ggml-impl.h" +#import "ggml-threading.h" + +#include + +#include + +#ifndef TARGET_OS_VISION +#define TARGET_OS_VISION 0 +#endif + +// create residency sets only on macOS >= 15.0 +#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \ + TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \ + TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \ + TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000 +#define GGML_METAL_HAS_RESIDENCY_SETS 1 +#endif + +// overload of MTLGPUFamilyMetal3 (not available in some environments) +static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; + +#if !GGML_METAL_EMBED_LIBRARY +// Here to assist with NSBundle Path Hack +@interface GGMLMetalClass : NSObject +@end +@implementation GGMLMetalClass +@end +#endif + +// +// MTLFunctionConstantValues wrapper +// + +struct ggml_metal_cv { + MTLFunctionConstantValues * obj; +}; + +ggml_metal_cv_t ggml_metal_cv_init(void) { + ggml_metal_cv_t res = calloc(1, sizeof(struct ggml_metal_cv)); + + res->obj = [[MTLFunctionConstantValues alloc] init]; + + return res; +} + +void ggml_metal_cv_free(ggml_metal_cv_t cv) { + [cv->obj release]; + free(cv); +} + +void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx) { + [cv->obj setConstantValue:&value type:MTLDataTypeShort atIndex:idx]; +} + +void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx) { + [cv->obj setConstantValue:&value type:MTLDataTypeInt atIndex:idx]; +} + +void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) { + [cv->obj setConstantValue:&value type:MTLDataTypeBool atIndex:idx]; +} + +// +// MTLComputePipelineState wrapper +// + +struct ggml_metal_pipeline { + id obj; + + // suggested dispatch sizes + int nsg; + + int nr0; + int nr1; + + size_t smem; +}; + +ggml_metal_pipeline_t ggml_metal_pipeline_init(void) { + ggml_metal_pipeline_t res = calloc(1, sizeof(struct ggml_metal_pipeline)); + + *res = (struct ggml_metal_pipeline) { + /*.obj =*/ nil, + /*.nsg =*/ 0, + /*.nr0 =*/ 0, + /*.nr1 =*/ 0, + /*.smem =*/ 0, + }; + + return res; +} + +void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) { + [pipeline->obj release]; + + free(pipeline); +} + +void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg) { + pipeline->nsg = nsg; +} + +int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline) { + return pipeline->nsg; +} + +void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0) { + pipeline->nr0 = nr0; +} + +int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline) { + return pipeline->nr0; +} + +void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1) { + pipeline->nr1 = nr1; +} + +int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline) { + return pipeline->nr1; +} + +void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem) { + pipeline->smem = smem; +} + +size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline) { + return pipeline->smem; +} + +int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline) { + return pipeline->obj.maxTotalThreadsPerThreadgroup; +} + +struct ggml_metal_library { + id obj; + id device; + + ggml_metal_pipelines_t pipelines; // cache of compiled pipelines +}; + +ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { + id library = nil; + id device = ggml_metal_device_get_obj(dev); + + // load library + // + // - first check if the library is embedded + // - then check if the library is in the bundle + // - if not found, load the source and compile it + // - if that fails, return NULL + // + // TODO: move to a function + { + const int64_t t_start = ggml_time_us(); + + NSError * error = nil; + NSString * src = nil; + +#if GGML_METAL_EMBED_LIBRARY + GGML_LOG_INFO("%s: using embedded metal library\n", __func__); + + extern const char ggml_metallib_start[]; + extern const char ggml_metallib_end[]; + + src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding]; +#else + +#ifdef SWIFT_PACKAGE + NSBundle * bundle = SWIFTPM_MODULE_BUNDLE; +#else + NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; +#endif + + NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"]; + if (path_lib == nil) { + // Try to find the resource in the directory where the current binary located. + NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0]; + NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent]; + + NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]]; + if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) { + GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]); + + NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error]; + if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) { + // Optionally, if this is a symlink, try to resolve it. + path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error]; + if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) { + // It is a relative path, adding the binary directory as directory prefix. + path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]]; + } + if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) { + // Link to the resource could not be resolved. + path_lib_default = nil; + } else { + GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]); + } + } + } else { + // The resource couldn't be found in the binary's directory. + path_lib_default = nil; + } + + path_lib = path_lib_default; + } + + if (path_lib != nil) { + // pre-compiled library found + NSURL * libURL = [NSURL fileURLWithPath:path_lib]; + GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]); + + library = [device newLibraryWithURL:libURL error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return nil; + } + } else { + GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); + + NSString * path_source; + NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"]; + + GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil"); + + if (path_resource) { + path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"]; + } else { + path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + } + + if (path_source == nil) { + GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); + path_source = @"ggml-metal.metal"; + } + + GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]); + + src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return nil; + } + } +#endif + + if (!library) { + @autoreleasepool { + // dictionary of preprocessor macros + NSMutableDictionary * prep = [NSMutableDictionary dictionary]; + + if (ggml_metal_device_get_props(dev)->has_bfloat) { + [prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"]; + } + +#if GGML_METAL_EMBED_LIBRARY + [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"]; +#endif + + MTLCompileOptions * options = [MTLCompileOptions new]; + options.preprocessorMacros = prep; + + //[options setFastMathEnabled:false]; + + library = [device newLibraryWithSource:src options:options error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return nil; + } + +#if !__has_feature(objc_arc) + [options release]; +#endif + } + } + +#if GGML_METAL_EMBED_LIBRARY + [src release]; +#endif // GGML_METAL_EMBED_LIBRARY + + GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6); + } + + ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library)); + + res->obj = library; + res->device = device; + res->pipelines = ggml_metal_pipelines_init(); + + return res; +} + +void ggml_metal_library_free(ggml_metal_library_t lib) { + if (!lib) { + return; + } + + if (lib->obj) { + [lib->obj release]; + } + + ggml_metal_pipelines_free(lib->pipelines); + + free(lib); +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { + return ggml_metal_pipelines_get(lib->pipelines, name); +} + +ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { + // note: the pipelines are cached in the library per device, so they are shared across all metal contexts + ggml_critical_section_start(); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + ggml_critical_section_end(); + + return res; + } + + res = ggml_metal_pipeline_init(); + + @autoreleasepool { + NSError * error = nil; + + NSString * base_func = [NSString stringWithUTF8String:base]; + + GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name); + + id mtl_function; + if (!cv) { + mtl_function = [lib->obj newFunctionWithName:base_func]; + } else { + mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error]; + } + if (!mtl_function) { + ggml_critical_section_end(); + + GGML_LOG_ERROR("%s: error: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name); + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + } + + return nil; + } + + res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; + + ggml_metal_pipelines_add(lib->pipelines, name, res); + + [mtl_function release]; + + GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj, + (int) res->obj.maxTotalThreadsPerThreadgroup, + (int) res->obj.threadExecutionWidth); + } + + ggml_critical_section_end(); + + return res; +} + +// +// MTLComputeCommandEncoder wrapper +// + +struct ggml_metal_encoder { + id obj; +}; + +ggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_cmd_buf_t cmd_buf_raw, bool concurrent) { + ggml_metal_encoder_t res = calloc(1, sizeof(struct ggml_metal_encoder)); + + id cmd_buf = (id) cmd_buf_raw; + + if (concurrent) { + res->obj = [cmd_buf computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } else { + res->obj = [cmd_buf computeCommandEncoder]; + } + + [res->obj retain]; + + return res; +} + +void ggml_metal_encoder_free(ggml_metal_encoder_t encoder) { + [encoder->obj release]; + free(encoder); +} + +void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name) { + [encoder->obj pushDebugGroup:[NSString stringWithCString:name encoding:NSUTF8StringEncoding]]; +} + +void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) { + [encoder->obj popDebugGroup]; +} + +void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline) { + [encoder->obj setComputePipelineState:pipeline->obj]; +} + +void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) { + [encoder->obj setBytes:data length:size atIndex:idx]; +} + +void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx) { + [encoder->obj setBuffer:buffer.metal offset:buffer.offs atIndex:idx]; +} + +void ggml_metal_encoder_set_threadgroup_memory_size(ggml_metal_encoder_t encoder, size_t size, int idx) { + [encoder->obj setThreadgroupMemoryLength:size atIndex:idx]; +} + +void ggml_metal_encoder_dispatch_threadgroups(ggml_metal_encoder_t encoder, int tg0, int tg1, int tg2, int tptg0, int tptg1, int tptg2) { + [encoder->obj dispatchThreadgroups:MTLSizeMake(tg0, tg1, tg2) threadsPerThreadgroup:MTLSizeMake(tptg0, tptg1, tptg2)]; +} + +void ggml_metal_encoder_memory_barrier(ggml_metal_encoder_t encoder) { + [encoder->obj memoryBarrierWithScope:MTLBarrierScopeBuffers]; +} + +void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder) { + [encoder->obj endEncoding]; +} + +struct ggml_metal_device { + id mtl_device; + + // a single global queue shared by all Metal backends + // technically not needed for devices with unified memory, but enables discrete GPUs support + // ref: https://github.com/ggml-org/llama.cpp/pull/15906 + id mtl_queue; + + ggml_metal_library_t library; + + struct ggml_metal_device_props props; +}; + +ggml_metal_device_t ggml_metal_device_init(void) { + ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device)); + + assert(dev != NULL); + + if (dev->mtl_device == nil) { + dev->mtl_device = MTLCreateSystemDefaultDevice(); + + if (dev->mtl_device) { + dev->mtl_queue = [dev->mtl_device newCommandQueue]; + if (dev->mtl_queue == nil) { + GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); + } + + dev->props.has_simdgroup_reduction = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; + dev->props.has_simdgroup_reduction |= [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; + + dev->props.has_simdgroup_mm = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; + dev->props.has_unified_memory = dev->mtl_device.hasUnifiedMemory; + + dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; + dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6]; + + dev->props.use_residency_sets = true; +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + dev->props.use_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil; +#endif + + dev->props.use_shared_buffers = dev->props.has_unified_memory; + + if (getenv("GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) { + dev->props.use_shared_buffers = false; + } + + dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; + + dev->props.max_buffer_size = dev->mtl_device.maxBufferLength; + dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize; + dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength; + + strncpy(dev->props.name, [[dev->mtl_device name] UTF8String], sizeof(dev->props.name) - 1); + + dev->library = ggml_metal_library_init(dev); + if (!dev->library) { + GGML_LOG_ERROR("%s: error: failed to create library\n", __func__); + } + + // -------------------------------------------------- + + // print MTL GPU family: + GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name); + + // determine max supported GPU family + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf + // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf + { + for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { + if ([dev->mtl_device supportsFamily:i]) { + GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); + break; + } + } + + for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) { + if ([dev->mtl_device supportsFamily:i]) { + GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i); + break; + } + } + + for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) { + if ([dev->mtl_device supportsFamily:i]) { + GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i); + break; + } + } + } + + GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, dev->props.has_simdgroup_reduction ? "true" : "false"); + GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, dev->props.has_simdgroup_mm ? "true" : "false"); + GGML_LOG_INFO("%s: has unified memory = %s\n", __func__, dev->props.has_unified_memory ? "true" : "false"); + GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, dev->props.has_bfloat ? "true" : "false"); + GGML_LOG_INFO("%s: use residency sets = %s\n", __func__, dev->props.use_residency_sets ? "true" : "false"); + GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, dev->props.use_shared_buffers ? "true" : "false"); + +#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) + if (@available(macOS 10.12, iOS 16.0, *)) { + GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, dev->props.max_working_set_size / 1e6); + } +#endif + } + } + + return dev; +} + +void ggml_metal_device_free(ggml_metal_device_t dev) { + assert(dev != NULL); + + ggml_metal_library_free(dev->library); + dev->library = NULL; + + if (dev->mtl_queue) { + [dev->mtl_queue release]; + dev->mtl_queue = nil; + } + + if (dev->mtl_device) { + [dev->mtl_device release]; + dev->mtl_device = nil; + } + + free(dev); +} + +void * ggml_metal_device_get_obj(ggml_metal_device_t dev) { + return dev->mtl_device; +} + +void * ggml_metal_device_get_queue(ggml_metal_device_t dev) { + return dev->mtl_queue; +} + +ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev) { + return dev->library; +} + +void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) { + if (@available(macOS 10.12, iOS 16.0, *)) { + *total = dev->mtl_device.recommendedMaxWorkingSetSize; + *free = *total - dev->mtl_device.currentAllocatedSize; + } else { + *free = 0; + *total = 0; + } +} + +bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op) { + const bool has_simdgroup_mm = dev->props.has_simdgroup_mm; + const bool has_simdgroup_reduction = dev->props.has_simdgroup_reduction; + const bool has_bfloat = dev->props.has_bfloat; + + if (!has_bfloat) { + if (op->type == GGML_TYPE_BF16) { + return false; + } + + for (size_t i = 0, n = 3; i < n; ++i) { + if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) { + return false; + } + } + } + + switch (op->op) { + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_EXP: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + default: + return false; + } + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + default: + return false; + } + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + case GGML_OP_CONCAT: + return true; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_ADD_ID: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ACC: + case GGML_OP_REPEAT: + case GGML_OP_SCALE: + case GGML_OP_CONV_TRANSPOSE_1D: + return true; + case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_LOG: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_SOFT_MAX: + case GGML_OP_GROUP_NORM: + return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: + return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); + case GGML_OP_ARGMAX: + return has_simdgroup_reduction; + case GGML_OP_NORM: + return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); + case GGML_OP_ROPE: + return true; + case GGML_OP_IM2COL: + return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); + case GGML_OP_POOL_1D: + return false; + case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + case GGML_OP_POOL_2D: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_PAD: + return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && + (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); + case GGML_OP_PAD_REFLECT_1D: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_ARGSORT: + case GGML_OP_LEAKY_RELU: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ARANGE: + return true; + case GGML_OP_FLASH_ATTN_EXT: + // for new head sizes, add checks here + if (op->src[0]->ne[0] != 40 && + op->src[0]->ne[0] != 64 && + op->src[0]->ne[0] != 80 && + op->src[0]->ne[0] != 96 && + op->src[0]->ne[0] != 112 && + op->src[0]->ne[0] != 128 && + op->src[0]->ne[0] != 192 && + op->src[0]->ne[0] != 256) { + return false; + } + if (op->src[0]->ne[0] == 576) { + // DeepSeek sizes + // TODO: disabled for now, until optmized + return false; + } + if (op->src[1]->type != op->src[2]->type) { + return false; + } + return has_simdgroup_mm; // TODO: over-restricted for vec-kernels + case GGML_OP_SSM_CONV: + case GGML_OP_SSM_SCAN: + return has_simdgroup_reduction; + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + return true; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + return has_simdgroup_reduction && + (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32); + case GGML_OP_CPY: + case GGML_OP_DUP: + case GGML_OP_CONT: + { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_I32: + return true; + default: + return false; + } + case GGML_TYPE_F16: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + default: + return false; + } + case GGML_TYPE_BF16: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_BF16: + return true; + default: + return false; + } + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + default: + return false; + } + case GGML_TYPE_I32: + return op->type == GGML_TYPE_F32; + default: + return false; + }; + } + case GGML_OP_GET_ROWS: + { + return op->ne[3] == 1; + } + case GGML_OP_SET_ROWS: + { + if (op->src[0]->type != GGML_TYPE_F32) { + return false; + } + + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_IQ4_NL: + return true; + default: + return false; + }; + } + default: + return false; + } +} + +const struct ggml_metal_device_props * ggml_metal_device_get_props(ggml_metal_device_t dev) { + return &dev->props; +} + +// +// device buffers +// + +// max memory buffers that can be mapped to the device +#define GGML_METAL_MAX_BUFFERS 64 + +struct ggml_metal_buffer_wrapper { + void * data; + size_t size; + + id metal; +}; + +struct ggml_metal_buffer { + void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985 + size_t all_size; + + // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host + bool is_shared; + bool owned; + + // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap + int n_buffers; + struct ggml_metal_buffer_wrapper buffers[GGML_METAL_MAX_BUFFERS]; + + bool use_residency_sets; + + // optional MTLResidencySet + // note: cannot use explicity "id" here because it is not available on certain OSes + id rset; + + // pointers to global device objects + id device; + id queue; +}; + +static void ggml_metal_log_allocated_size(id device, size_t size_aligned) { +#ifndef GGML_METAL_NDEBUG +#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) + if (@available(macOS 10.12, iOS 16.0, *)) { + GGML_LOG_DEBUG("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0, + device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + + if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) { + GGML_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); + } + } else { + GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0); + } +#endif +#endif + GGML_UNUSED(device); + GGML_UNUSED(size_aligned); +} + +// rset init +static bool ggml_metal_buffer_rset_init(ggml_metal_buffer_t buf) { + buf->rset = nil; + + if (!buf->use_residency_sets) { + return true; + } + +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { + MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init]; + desc.label = @"ggml_metal"; + desc.initialCapacity = buf->n_buffers; + + NSError * error; + buf->rset = [buf->device newResidencySetWithDescriptor:desc error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + [desc release]; + return false; + } + + [desc release]; + + for (int i = 0; i < buf->n_buffers; i++) { + [buf->rset addAllocation:buf->buffers[i].metal]; + } + + [buf->rset commit]; + [buf->rset requestResidency]; + + return true; + } +#endif + + return true; +} + +// rset free +static void ggml_metal_buffer_rset_free(ggml_metal_buffer_t buf) { +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { + if (buf->rset) { + [buf->rset endResidency]; + [buf->rset removeAllAllocations]; + [buf->rset release]; + } + } +#else + GGML_UNUSED(buf); +#endif +} + +static void * ggml_metal_host_malloc(size_t n) { + void * data = NULL; + +#if TARGET_OS_OSX + kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE); + if (err != KERN_SUCCESS) { + GGML_LOG_ERROR("%s: error: vm_allocate failed\n", __func__); + return NULL; + } +#else + const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); + if (result != 0) { + GGML_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); + return NULL; + } +#endif + + return data; +} + +ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared) { + ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer)); + + const size_t size_page = sysconf(_SC_PAGESIZE); + + size_t size_aligned = size; + if ((size_aligned % size_page) != 0) { + size_aligned += (size_page - (size_aligned % size_page)); + } + + const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + + shared = shared && props_dev->use_shared_buffers; + + // allocate shared buffer if the device supports it and it is required by the buffer type + if (shared) { + res->all_data = ggml_metal_host_malloc(size_aligned); + res->is_shared = true; + res->owned = true; + } else { + // dummy, non-NULL value - we'll populate this after creating the Metal buffer below + res->all_data = (void *) 0x000000400ULL; + res->is_shared = false; + } + res->all_size = size_aligned; + + res->device = ggml_metal_device_get_obj(dev); + res->queue = ggml_metal_device_get_queue(dev); + + res->n_buffers = 1; + + if (res->all_data != NULL) { + res->buffers[0].size = size; + res->buffers[0].metal = nil; + + if (size_aligned > 0) { + if (props_dev->use_shared_buffers &&shared) { + res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data + length:size_aligned + options:MTLResourceStorageModeShared + deallocator:nil]; + } else { + res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate]; + + res->all_data = (void *) (res->buffers[0].metal.gpuAddress); + } + } + + res->buffers[0].data = res->all_data; + } + + if (size_aligned > 0 && (res->all_data == NULL || res->buffers[0].metal == nil)) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + free(res); + return NULL; + } + + res->use_residency_sets = props_dev->use_residency_sets; + + if (!ggml_metal_buffer_rset_init(res)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(res); + return NULL; + } + + //ggml_metal_log_allocated_size(device, size_aligned); + + return res; +} + +ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size) { + ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer)); + + res->all_data = ptr; + res->all_size = size; + + res->is_shared = true; + res->owned = false; + + res->n_buffers = 0; + + const size_t size_page = sysconf(_SC_PAGESIZE); + + // page-align the data ptr + { + const uintptr_t offs = (uintptr_t) ptr % size_page; + ptr = (void *) ((char *) ptr - offs); + size += offs; + } + + size_t size_aligned = size; + if ((size_aligned % size_page) != 0) { + size_aligned += (size_page - (size_aligned % size_page)); + } + + res->device = ggml_metal_device_get_obj(dev); + res->queue = ggml_metal_device_get_queue(dev); + + const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + + // the buffer fits into the max buffer size allowed by the device + if (size_aligned <= props_dev->max_buffer_size) { + res->buffers[res->n_buffers].data = ptr; + res->buffers[res->n_buffers].size = size; + res->buffers[res->n_buffers].metal = nil; + + if (size_aligned > 0) { + res->buffers[res->n_buffers].metal = [res->device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (res->buffers[res->n_buffers].metal == nil) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + free(res); + return NULL; + } + } + + ggml_metal_log_allocated_size(res->device, size_aligned); + + ++res->n_buffers; + } else { + // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into + // one of the views + const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case + const size_t size_step = props_dev->max_buffer_size - size_ovlp; + const size_t size_view = props_dev->max_buffer_size; + + for (size_t i = 0; i < size; i += size_step) { + const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); + + res->buffers[res->n_buffers].data = (void *) ((uint8_t *) ptr + i); + res->buffers[res->n_buffers].size = size_step_aligned; + res->buffers[res->n_buffers].metal = nil; + + if (size_step_aligned > 0) { + res->buffers[res->n_buffers].metal = [res->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (res->buffers[res->n_buffers].metal == nil) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); + free(res); + return NULL; + } + } + + ggml_metal_log_allocated_size(res->device, size_step_aligned); + + if (i + size_step < size) { + GGML_LOG_INFO("\n"); + } + + ++res->n_buffers; + } + } + + res->use_residency_sets = props_dev->use_residency_sets; + + if (!ggml_metal_buffer_rset_init(res)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(res); + return NULL; + } + + return res; +} + +void ggml_metal_buffer_free(ggml_metal_buffer_t buf) { + for (int i = 0; i < buf->n_buffers; i++) { + [buf->buffers[i].metal release]; + } + + ggml_metal_buffer_rset_free(buf); + + if (buf->is_shared && buf->owned) { +#if TARGET_OS_OSX + vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)buf->all_data, buf->all_size); +#else + free(buf->all_data); +#endif + } + + free(buf); +} + +void * ggml_metal_buffer_get_base(ggml_metal_buffer_t buf) { + return buf->all_data; +} + +bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) { + return buf->is_shared; +} + +void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + if (buf->is_shared) { + memset((char *)tensor->data + offset, value, size); + return; + } + + @autoreleasepool { + // dst + struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor); + bid_dst.offs += offset; + + id queue = buf->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder fillBuffer:bid_dst.metal + range:NSMakeRange(bid_dst.offs, bid_dst.offs + size) + value:value]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } +} + +void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + if (buf->is_shared) { + memcpy((char *)tensor->data + offset, data, size); + return; + } + + @autoreleasepool { + // src + void * data_ptr = (void *)(uintptr_t) data; // "const cast" the src data + id buf_src = [buf->device newBufferWithBytesNoCopy:data_ptr + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; + + // dst + struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor); + bid_dst.offs += offset; + + // note: for experimentation purposes, here we use a semaphore to wait for the copy to complete + // this is alternative to waitUntilCompleted, which should be faster, but don't seem to make much difference + dispatch_semaphore_t completion_semaphore = dispatch_semaphore_create(0); + + id queue = buf->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:buf_src + sourceOffset:0 + toBuffer:bid_dst.metal + destinationOffset:bid_dst.offs + size:size]; + + [encoder endEncoding]; + } + + [cmd_buf addCompletedHandler:^(id cb) { + // TODO: can check for errors here + GGML_UNUSED(cb); + + dispatch_semaphore_signal(completion_semaphore); + }]; + + [cmd_buf commit]; + + dispatch_semaphore_wait(completion_semaphore, DISPATCH_TIME_FOREVER); + dispatch_release(completion_semaphore); + + //[cmd_buf waitUntilCompleted]; + } +} + +void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + if (buf->is_shared) { + memcpy(data, (const char *)tensor->data + offset, size); + return; + } + + @autoreleasepool { + // src + struct ggml_metal_buffer_id bid_src = ggml_metal_buffer_get_id(buf, tensor); + bid_src.offs += offset; + + // dst + id buf_dst = [buf->device newBufferWithBytesNoCopy:data + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; + + id queue = buf->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:bid_src.metal + sourceOffset:bid_src.offs + toBuffer:buf_dst + destinationOffset:0 + size:size]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } +} + +void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) { + if (buf->is_shared) { + memset(buf->all_data, value, buf->all_size); + return; + } + + @autoreleasepool { + id queue = buf->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder fillBuffer:buf->buffers[0].metal + range:NSMakeRange(0, buf->buffers[0].size) + value:value]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } +} + +struct ggml_metal_buffer_id ggml_metal_buffer_get_id(ggml_metal_buffer_t buf, const struct ggml_tensor * t) { + struct ggml_metal_buffer_id res = { nil, 0 }; + + const int64_t tsize = ggml_nbytes(t); + + // find the view that contains the tensor fully + for (int i = 0; i < buf->n_buffers; ++i) { + const int64_t ioffs = (int64_t) t->data - (int64_t) buf->buffers[i].data; + + //GGML_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf->buffers[i].size); + if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf->buffers[i].size) { + res.metal = buf->buffers[i].metal; + res.offs = (size_t) ioffs; + + //GGML_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs); + + return res; + } + } + + GGML_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name); + + return res; +} diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index fc6526d6d5d..fcebd8dfdf5 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -8,6 +8,9 @@ // // TODO: for optimal performance, become function of the device and work size +#define N_R0_F 2 +#define N_SG_F 4 + #define N_R0_Q4_0 4 #define N_SG_Q4_0 2 @@ -20,8 +23,8 @@ #define N_R0_Q5_1 4 #define N_SG_Q5_1 2 -#define N_R0_Q8_0 4 -#define N_SG_Q8_0 2 +#define N_R0_Q8_0 2 +#define N_SG_Q8_0 4 #define N_R0_MXFP4 2 #define N_SG_MXFP4 2 @@ -32,13 +35,13 @@ #define N_R0_Q3_K 2 #define N_SG_Q3_K 2 -#define N_R0_Q4_K 4 +#define N_R0_Q4_K 2 #define N_SG_Q4_K 2 #define N_R0_Q5_K 2 #define N_SG_Q5_K 2 -#define N_R0_Q6_K 1 +#define N_R0_Q6_K 2 #define N_SG_Q6_K 2 #define N_R0_IQ1_S 4 @@ -68,6 +71,12 @@ #define N_R0_IQ4_XS 2 #define N_SG_IQ4_XS 2 +// function constants offsets +#define FC_FLASH_ATTN_EXT 100 +#define FC_FLASH_ATTN_EXT_VEC 200 +#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300 +#define FC_MUL_MV 400 + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage @@ -160,6 +169,16 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_repeat; +typedef struct { + float scale; + float bias; +} ggml_metal_kargs_scale; + +typedef struct { + float min; + float max; +} ggml_metal_kargs_clamp; + typedef struct { int64_t ne00; int64_t ne01; @@ -236,9 +255,11 @@ typedef struct { int32_t ne11; int32_t ne_12_2; // assume K and V are same shape int32_t ne_12_3; + int32_t ns10; uint64_t nb11; uint64_t nb12; uint64_t nb13; + int32_t ns20; uint64_t nb21; uint64_t nb22; uint64_t nb23; @@ -249,6 +270,7 @@ typedef struct { uint64_t nb33; int32_t ne1; int32_t ne2; + int32_t ne3; float scale; float max_bias; float m0; @@ -257,6 +279,44 @@ typedef struct { float logit_softcap; } ggml_metal_kargs_flash_attn_ext; +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + int32_t ns10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ns20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; + int32_t ne1; + int32_t ne2; + int32_t ne3; + float scale; + float max_bias; + float m0; + float m1; + int32_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext_vec; + +typedef struct { + int32_t nrows; +} ggml_metal_kargs_flash_attn_ext_vec_reduce; + typedef struct { int32_t ne00; int32_t ne02; @@ -314,46 +374,34 @@ typedef struct { int32_t ne1; int16_t r2; int16_t r3; - int16_t nsg; - int16_t nxpsg; - int16_t r1ptg; } ggml_metal_kargs_mul_mv_ext; typedef struct { + int32_t ne02; int32_t ne10; int32_t ne11; // n_expert_used (bcast) uint64_t nb11; uint64_t nb12; - int32_t neh11; // n_tokens - uint64_t nbh11; + int32_t ne21; // n_tokens int32_t ne20; // n_expert_used uint64_t nb21; } ggml_metal_kargs_mul_mm_id_map0; -typedef struct { - int32_t ne20; // n_expert_used - int32_t neh0; - int32_t neh1; - uint64_t nbh1; - uint64_t nbh2; - int32_t ne0; - uint64_t nb1; - uint64_t nb2; -} ggml_metal_kargs_mul_mm_id_map1; - typedef struct { int32_t ne00; int32_t ne02; uint64_t nb01; uint64_t nb02; uint64_t nb03; - int32_t neh12; - uint64_t nbh10; - uint64_t nbh11; - uint64_t nbh12; - uint64_t nbh13; - int32_t neh0; - int32_t neh1; + int32_t ne11; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne20; + int32_t ne21; + int32_t ne0; + int32_t ne1; int16_t r2; int16_t r3; } ggml_metal_kargs_mul_mm_id; @@ -383,8 +431,16 @@ typedef struct { typedef struct { int32_t ne00; int32_t ne00_4; - uint64_t nb01; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; float eps; + int32_t nef1[3]; + int32_t nef2[3]; + int32_t nef3[3]; + uint64_t nbf1[3]; + uint64_t nbf2[3]; + uint64_t nbf3[3]; } ggml_metal_kargs_norm; typedef struct { @@ -416,7 +472,7 @@ typedef struct { uint64_t nb00; uint64_t nb01; uint64_t nb02; - int32_t n_groups; + int32_t ngrp; float eps; } ggml_metal_kargs_group_norm; @@ -469,14 +525,6 @@ typedef struct { uint64_t nb01; uint64_t nb02; uint64_t nb03; - int64_t ne10; - int64_t ne11; - int64_t ne12; - int64_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; int64_t ne0; int64_t ne1; int64_t ne2; @@ -510,12 +558,6 @@ typedef struct { int32_t n_head_log2; } ggml_metal_kargs_soft_max; -typedef struct { - int64_t ne00; - int64_t ne01; - int n_past; -} ggml_metal_kargs_diag_mask_inf; - typedef struct { int64_t ne00; int64_t ne01; @@ -542,7 +584,7 @@ typedef struct { int64_t n_group; int64_t n_seq_tokens; int64_t n_seqs; - int64_t s_off; + uint64_t s_off; uint64_t nb01; uint64_t nb02; uint64_t nb03; @@ -682,7 +724,12 @@ typedef struct { int64_t IW; int64_t OH; int64_t OW; - int64_t parallel_elements; + int64_t np; } ggml_metal_kargs_pool_2d; +typedef struct { + int64_t ne00; + uint64_t nb01; +} ggml_metal_kargs_argmax; + #endif // GGML_METAL_IMPL diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp new file mode 100644 index 00000000000..eeffcd46c30 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -0,0 +1,3284 @@ +#include "ggml-metal-ops.h" + +#include "ggml.h" +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include "ggml-metal-impl.h" +#include "ggml-metal-common.h" +#include "ggml-metal-device.h" + +#include +#include + +static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) { + if (!t) { + return { nullptr, 0 }; + } + + ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer; + + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t) buffer->context; + + return ggml_metal_buffer_get_id(ctx, t); +} + +struct ggml_metal_op { + ggml_metal_device_t dev; + ggml_metal_library_t lib; + ggml_metal_encoder_t enc; + ggml_mem_ranges_t mem_ranges; + + ggml_cgraph * gf; + + int idx_start; + int idx_end; + + bool use_fusion; + bool use_concurrency; + bool use_capture; + + int debug_graph; + int debug_fusion; +}; + +ggml_metal_op_t ggml_metal_op_init( + ggml_metal_device_t dev, + ggml_metal_cmd_buf_t cmd_buf, + ggml_cgraph * gf, + int idx_start, + int idx_end, + bool use_fusion, + bool use_concurrency, + bool use_capture, + int debug_graph, + int debug_fusion) { + ggml_metal_op_t res = new ggml_metal_op(); + + *res = { + /*.dev =*/ dev, + /*.lib =*/ ggml_metal_device_get_library(dev), + /*.enc =*/ ggml_metal_encoder_init(cmd_buf, use_concurrency), + /*.mem_ranges =*/ ggml_mem_ranges_init(debug_graph), + /*.gf =*/ gf, + /*.idx_start =*/ idx_start, + /*.idx_end =*/ idx_end, + /*.use_fusion =*/ use_fusion, + /*.use_concurrency =*/ use_concurrency, + /*.use_capture =*/ use_capture, + /*.debug_graph =*/ debug_graph, + /*.debug_fusion =*/ debug_fusion, + }; + + return res; +} + +void ggml_metal_op_free(ggml_metal_op_t ctx) { + ggml_metal_encoder_end_encoding(ctx->enc); + ggml_metal_encoder_free(ctx->enc); + ggml_mem_ranges_free(ctx->mem_ranges); + + delete ctx; +} + +static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) { + if (!ctx->mem_ranges) { + return true; + } + + ggml_metal_encoder_memory_barrier(ctx->enc); + + ggml_mem_ranges_reset(ctx->mem_ranges); + + return true; +} + +static bool ggml_metal_op_concurrency_check(ggml_metal_op_t ctx, const ggml_tensor * node) { + if (!ctx->mem_ranges) { + return false; + } + + return ggml_mem_ranges_check(ctx->mem_ranges, node); +} + +static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor * node) { + if (!ctx->mem_ranges) { + return true; + } + + return ggml_mem_ranges_add(ctx->mem_ranges, node); +} + +static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { + struct ggml_cgraph * gf = ctx->gf; + + struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx; + struct ggml_tensor * node = nodes[0]; + + //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op)); + + if (ggml_is_empty(node)) { + return 1; + } + + switch (node->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + { + // noop -> next node + } return 1; + default: + { + } break; + } + + if (!ggml_metal_device_supports_op(ctx->dev, node)) { + GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(node)); + GGML_ABORT("unsupported op"); + } + + int n_fuse = 1; + + // check if the current node can run concurrently with other nodes before it + // the condition is that: + // - the current node cannot write to any previous src or dst ranges + // - the current node cannot read from any previous dst ranges + // + // if the condition is not satisfied, we put a memory barrier and clear all ranges + // otherwise, we add the new ranges to the encoding context and process the node concurrently + // + { + const bool is_concurrent = ggml_metal_op_concurrency_check(ctx, node); + + if (!is_concurrent) { + ggml_metal_op_concurrency_reset(ctx); + } + + if (ctx->debug_graph > 0) { + GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), is_concurrent ? "(concurrent)" : ""); + } + if (ctx->debug_graph > 1) { + GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb); + GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb); + GGML_TENSOR_LOCALS( int64_t, ne, node, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, node, nb); + + if (node->src[0]) { + GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[0]->type), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, + ggml_is_contiguous(node->src[0]), node->src[0]->name); + } + if (node->src[1]) { + GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, + ggml_is_contiguous(node->src[1]), node->src[1]->name); + } + if (node) { + GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, + node->name); + } + } + } + + switch (node->op) { + case GGML_OP_CONCAT: + { + n_fuse = ggml_metal_op_concat(ctx, idx); + } break; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + { + n_fuse = ggml_metal_op_bin(ctx, idx); + } break; + case GGML_OP_ADD_ID: + { + n_fuse = ggml_metal_op_add_id(ctx, idx); + } break; + case GGML_OP_REPEAT: + { + n_fuse = ggml_metal_op_repeat(ctx, idx); + } break; + case GGML_OP_ACC: + { + n_fuse = ggml_metal_op_acc(ctx, idx); + } break; + case GGML_OP_SCALE: + { + n_fuse = ggml_metal_op_scale(ctx, idx); + } break; + case GGML_OP_CLAMP: + { + n_fuse = ggml_metal_op_clamp(ctx, idx); + } break; + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_LOG: + case GGML_OP_UNARY: + { + n_fuse = ggml_metal_op_unary(ctx, idx); + } break; + case GGML_OP_GLU: + { + n_fuse = ggml_metal_op_glu(ctx, idx); + } break; + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + { + n_fuse = ggml_metal_op_sum_rows(ctx, idx); + } break; + case GGML_OP_SOFT_MAX: + { + n_fuse = ggml_metal_op_soft_max(ctx, idx); + } break; + case GGML_OP_SSM_CONV: + { + n_fuse = ggml_metal_op_ssm_conv(ctx, idx); + } break; + case GGML_OP_SSM_SCAN: + { + n_fuse = ggml_metal_op_ssm_scan(ctx, idx); + } break; + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + { + n_fuse = ggml_metal_op_rwkv(ctx, idx); + } break; + case GGML_OP_MUL_MAT: + { + n_fuse = ggml_metal_op_mul_mat(ctx, idx); + } break; + case GGML_OP_MUL_MAT_ID: + { + n_fuse = ggml_metal_op_mul_mat_id(ctx, idx); + } break; + case GGML_OP_GET_ROWS: + { + n_fuse = ggml_metal_op_get_rows(ctx, idx); + } break; + case GGML_OP_SET_ROWS: + { + n_fuse = ggml_metal_op_set_rows(ctx, idx); + } break; + case GGML_OP_RMS_NORM: + { + n_fuse = ggml_metal_op_rms_norm(ctx, idx); + } break; + case GGML_OP_L2_NORM: + { + n_fuse = ggml_metal_op_l2_norm(ctx, idx); + } break; + case GGML_OP_GROUP_NORM: + { + n_fuse = ggml_metal_op_group_norm(ctx, idx); + } break; + case GGML_OP_NORM: + { + n_fuse = ggml_metal_op_norm(ctx, idx); + } break; + case GGML_OP_ROPE: + { + n_fuse = ggml_metal_op_rope(ctx, idx); + } break; + case GGML_OP_IM2COL: + { + n_fuse = ggml_metal_op_im2col(ctx, idx); + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx); + } break; + case GGML_OP_UPSCALE: + { + n_fuse = ggml_metal_op_upscale(ctx, idx); + } break; + case GGML_OP_PAD: + { + n_fuse = ggml_metal_op_pad(ctx, idx); + } break; + case GGML_OP_PAD_REFLECT_1D: + { + n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx); + } break; + case GGML_OP_ARANGE: + { + n_fuse = ggml_metal_op_arange(ctx, idx); + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + n_fuse = ggml_metal_op_timestep_embedding(ctx, idx); + } break; + case GGML_OP_ARGSORT: + { + n_fuse = ggml_metal_op_argsort(ctx, idx); + } break; + case GGML_OP_LEAKY_RELU: + { + n_fuse = ggml_metal_op_leaky_relu(ctx, idx); + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx); + } break; + case GGML_OP_DUP: + case GGML_OP_CPY: + case GGML_OP_CONT: + { + n_fuse = ggml_metal_op_cpy(ctx, idx); + } break; + case GGML_OP_POOL_2D: + { + n_fuse = ggml_metal_op_pool_2d(ctx, idx); + } break; + case GGML_OP_ARGMAX: + { + n_fuse = ggml_metal_op_argmax(ctx, idx); + } break; + default: + { + GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op)); + GGML_ABORT("fatal error"); + } + } + + if (ctx->debug_graph > 0) { + if (n_fuse > 1) { + GGML_LOG_DEBUG("%s: fuse %d ops\n", __func__, n_fuse); + } + } + + // update the mem ranges in the encoding context + for (int i = 0; i < n_fuse; ++i) { + if (!ggml_metal_op_concurrency_add(ctx, nodes[i])) { + ggml_metal_op_concurrency_reset(ctx); + } + } + + return n_fuse; +} + +int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) { + if (ctx->use_capture) { + ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ggml_graph_node(ctx->gf, idx))); + } + + int res = ggml_metal_op_encode_impl(ctx, idx); + if (idx + res > ctx->idx_end) { + GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s", + "https://github.com/ggml-org/llama.cpp/pull/14849"); + } + + if (ctx->use_capture) { + ggml_metal_encoder_debug_group_pop(ctx->enc); + } + + return res; +} + +int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t dim = ((const int32_t *) op->op_params)[0]; + + ggml_metal_kargs_concat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.dim =*/ dim, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + const int nth = std::min(1024, ne0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type); + + ggml_metal_kargs_repeat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + + const size_t pnb1 = ((const int32_t *) op->op_params)[0]; + const size_t pnb2 = ((const int32_t *) op->op_params)[1]; + const size_t pnb3 = ((const int32_t *) op->op_params)[2]; + const size_t offs = ((const int32_t *) op->op_params)[3]; + + const bool inplace = (bool) ((const int32_t *) op->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + //const id pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj; + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + + ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + ggml_metal_op_concurrency_reset(ctx); + } + + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ pnb1, + /*.nb02 =*/ pnb2, + /*.nb03 =*/ pnb3, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + /*.offs =*/ offs, + /*.o1 =*/ { 0 }, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float scale; + float bias; + memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float)); + memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float)); + + ggml_metal_kargs_scale args = { + /*.scale =*/ scale, + /*.bias =*/ bias, + }; + + int64_t n = ggml_nelements(op); + + if (n % 4 == 0) { + n /= 4; + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + + return 1; +} + +int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float min; + float max; + memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float)); + memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float)); + + ggml_metal_kargs_clamp args = { + /*.min =*/ min, + /*.max =*/ max, + }; + + int64_t n = ggml_nelements(op); + + if (n % 4 == 0) { + n /= 4; + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + + return 1; +} + +int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + int64_t n = ggml_nelements(op); + + if (n % 4 == 0) { + n /= 4; + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1); + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + + return 1; +} + +int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + if (op->src[1]) { + GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1])); + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_glu(lib, op); + + const int32_t swp = ggml_get_op_params_i32(op, 1); + const float alpha = ggml_get_op_params_f32(op, 2); + const float limit = ggml_get_op_params_f32(op, 3); + + const int32_t i00 = swp ? ne0 : 0; + const int32_t i10 = swp ? 0 : ne0; + + ggml_metal_kargs_glu args = { + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.ne10 =*/ op->src[1] ? ne10 : ne00, + /*.nb11 =*/ op->src[1] ? nb11 : nb01, + /*.ne0 =*/ ne0, + /*.nb1 =*/ nb1, + /*.i00 =*/ op->src[1] ? 0 : i00, + /*.i10 =*/ op->src[1] ? 0 : i10, + /*.alpha=*/ alpha, + /*.limit=*/ limit + }; + + const int64_t nrows = ggml_nrows(op->src[0]); + + const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2); + + //[encoder setComputePipelineState:pipeline]; + //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + //if (src1) { + // [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + //} else { + // [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + //} + //[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + //[encoder setBytes:&args length:sizeof(args) atIndex:3]; + + //[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + if (op->src[1]) { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + } else { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 2); + } + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_kargs_sum_rows args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + //[encoder setComputePipelineState:pipeline]; + //[encoder setBytes:&args length:sizeof(args) atIndex:0]; + //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + //[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + //[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + //[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); + + ggml_metal_kargs_get_rows args = { + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1); + + return 1; +} + +int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->type); + + const int32_t nk0 = ne0/ggml_blck_size(op->type); + + int nth = 32; // SIMD width + + while (nth < nk0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + int nrptg = 1; + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; + + if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nrptg--; + } + } + + nth = std::min(nth, nk0); + + ggml_metal_kargs_set_rows args = { + /*.nk0 =*/ nk0, + /*.ne01 =*/ ne01, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1); + + return 1; +} + +int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float scale; + float max_bias; + + memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias)); + + const uint32_t n_head = op->src[0]->ne[2]; + const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + // softmax + + ggml_metal_kargs_soft_max args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op); + + int nth = 32; // SIMD width + + if (ne00%4 == 0) { + while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + } else { + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + } + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + if (op->src[1]) { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + } else { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 2); + } + if (op->src[2]) { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[2]), 3); + } else { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3); + } + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 4); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_kargs_ssm_conv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); + + return 1; +} + +int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + GGML_TENSOR_LOCALS( int32_t, ne4, op->src[4], ne); + GGML_TENSOR_LOCALS(uint64_t, nb4, op->src[4], nb); + GGML_TENSOR_LOCALS( int32_t, ne5, op->src[5], ne); + GGML_TENSOR_LOCALS(uint64_t, nb5, op->src[5], nb); + GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne); + GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const ggml_tensor * src3 = op->src[3]; + const ggml_tensor * src4 = op->src[4]; + const ggml_tensor * src5 = op->src[5]; + const ggml_tensor * src6 = op->src[6]; + + GGML_ASSERT(src3); + GGML_ASSERT(src4); + GGML_ASSERT(src5); + GGML_ASSERT(src6); + + const int64_t d_state = ne00; + const int64_t d_inner = ne01; + const int64_t n_head = ne02; + const int64_t n_group = ne41; + const int64_t n_seq_tokens = ne12; + const int64_t n_seqs = ne13; + + ggml_metal_kargs_ssm_scan args = { + /*.d_state =*/ d_state, + /*.d_inner =*/ d_inner, + /*.n_head =*/ n_head, + /*.n_group =*/ n_group, + /*.n_seq_tokens =*/ n_seq_tokens, + /*.n_seqs =*/ n_seqs, + /*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float), + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb31 =*/ nb31, + /*.nb41 =*/ nb41, + /*.nb42 =*/ nb42, + /*.nb43 =*/ nb43, + /*.nb51 =*/ nb51, + /*.nb52 =*/ nb52, + /*.nb53 =*/ nb53, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); + + const size_t sms = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 6); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0); + + if (ne30 == 1) { + // Mamba-2 + ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); + } else { + GGML_ASSERT(d_inner == 1); + ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1); + } + + return 1; +} + +int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1]; + const int64_t T = op->src[0]->ne[2]; + const int64_t C = op->ne[0]; + const int64_t H = op->src[0]->ne[1]; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); + if (op->op == GGML_OP_RWKV_WKV7) { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++); + } + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++); + + ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1); + + return 1; +} + +int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + + GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0); + + // TODO: support + //const int32_t nk00 = ne00/ggml_blck_size(op->type); + const int32_t nk00 = ne00; + + int nth = 32; // SIMD width + + while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + // when rows are small, we can batch them together in a single threadgroup + int nrptg = 1; + + // TODO: relax this constraint in the future + if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) { + if (nth > nk00) { + nrptg = (nth + nk00 - 1)/nk00; + nth = nk00; + + if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nrptg--; + } + } + } + + nth = std::min(nth, nk00); + + ggml_metal_kargs_cpy args = { + /*.ne00 =*/ nk00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1); + + return 1; +} + +int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int32_t * opts = op->op_params; + ggml_op_pool op_pool = (ggml_op_pool) opts[0]; + + const int32_t k0 = opts[1]; + const int32_t k1 = opts[2]; + const int32_t s0 = opts[3]; + const int32_t s1 = opts[4]; + const int32_t p0 = opts[5]; + const int32_t p1 = opts[6]; + + const int64_t IH = op->src[0]->ne[1]; + const int64_t IW = op->src[0]->ne[0]; + + const int64_t N = op->ne[3]; + const int64_t OC = op->ne[2]; + const int64_t OH = op->ne[1]; + const int64_t OW = op->ne[0]; + + const int64_t np = N * OC * OH * OW; + + ggml_metal_kargs_pool_2d args_pool_2d = { + /* .k0 = */ k0, + /* .k1 = */ k1, + /* .s0 = */ s0, + /* .s1 = */ s1, + /* .p0 = */ p0, + /* .p1 = */ p1, + /* .IH = */ IH, + /* .IW = */ IW, + /* .OH = */ OH, + /* .OW = */ OW, + /* .np = */ np + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np); + const int ntg = (np + nth - 1) / nth; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args_pool_2d, sizeof(args_pool_2d), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + GGML_ASSERT(ne00 == ne10); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + const int16_t r2 = ne12/ne02; + const int16_t r3 = ne13/ne03; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + const int ne11_mm_min = 8; + + // first try to use small-batch mat-mv kernels + // these should be efficient for BS [2, ~8] + if (op->src[1]->type == GGML_TYPE_F32 && (ne00%128 == 0) && + ( + ( + ( + op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function + op->src[0]->type == GGML_TYPE_F16 || + op->src[0]->type == GGML_TYPE_Q4_0 || + op->src[0]->type == GGML_TYPE_Q4_1 || + op->src[0]->type == GGML_TYPE_Q5_0 || + op->src[0]->type == GGML_TYPE_Q5_1 || + op->src[0]->type == GGML_TYPE_Q8_0 || + op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_IQ4_NL || + false) && (ne11 >= 2 && ne11 <= 8) + ) || + ( + ( + op->src[0]->type == GGML_TYPE_Q4_K || + op->src[0]->type == GGML_TYPE_Q5_K || + op->src[0]->type == GGML_TYPE_Q6_K || + false) && (ne11 >= 4 && ne11 <= 8) + ) + ) + ) { + // TODO: determine the optimal parameters based on grid utilization + // I still don't know why we should not always use the maximum available threads: + // + // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32 + // + // my current hypothesis is that the work grid is not evenly divisible for different nsg + // values and there can be some tail effects when nsg is high. need to confirm this + // + const int nsg = 2; // num simdgroups per threadgroup + + // num threads along row per simdgroup + int16_t nxpsg = 0; + if (ne00 % 256 == 0 && ne11 < 3) { + nxpsg = 16; + } else if (ne00 % 128 == 0) { + nxpsg = 8; + } else { + nxpsg = 4; + } + + const int16_t nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time) + const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup + int16_t r1ptg = 4; // num src1 rows per threadgroup + + // note: not sure how optimal are those across all different hardware. there might be someting cleverer + switch (ne11) { + case 2: + r1ptg = 2; break; + case 3: + case 6: + r1ptg = 3; break; + case 4: + case 7: + case 8: + r1ptg = 4; break; + case 5: + r1ptg = 5; break; + default: + GGML_ABORT("unsupported ne11"); + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); + + ggml_metal_kargs_mul_mv_ext args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + r0ptg - 1)/r0ptg), ((ne11 + r1ptg - 1)/r1ptg), ne12*ne13, 32, nsg, 1); + } else if ( + !ggml_is_transposed(op->src[0]) && + !ggml_is_transposed(op->src[1]) && + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + props_dev->has_simdgroup_mm && + op->src[1]->type == GGML_TYPE_F32 && + ne00 % 32 == 0 && ne00 >= 64 && + (ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) { + //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + switch (op->src[0]->type) { + case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; + default: break; + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op->src[0]->type, op->src[1]->type); + + ggml_metal_kargs_mul_mm args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1); + } else { + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); + + ggml_metal_kargs_mul_mv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); + const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); + const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + if (op->src[0]->type == GGML_TYPE_F32 || + op->src[0]->type == GGML_TYPE_F16 || + op->src[0]->type == GGML_TYPE_BF16 || + op->src[0]->type == GGML_TYPE_Q8_0) { + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1); + } else { + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1); + } + } + + return 1; +} + +size_t ggml_metal_op_mul_mat_id_extra_tpe(const ggml_tensor * op) { + assert(op->op == GGML_OP_MUL_MAT_ID); + + const int64_t ne02 = op->src[0]->ne[2]; // n_expert + + return ggml_type_size(GGML_TYPE_I32)*ne02; +} + +size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) { + assert(op->op == GGML_OP_MUL_MAT_ID); + + const int64_t ne02 = op->src[0]->ne[2]; // n_expert + const int64_t ne21 = op->src[2]->ne[1]; // n_token + + return ggml_type_size(GGML_TYPE_I32)*ne02*ne21; +} + +int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + // src2 = ids + GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32); + + GGML_ASSERT(!ggml_is_transposed(op->src[0])); + GGML_ASSERT(!ggml_is_transposed(op->src[1])); + + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne13 == 1); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + const uint32_t r2 = 1; + const uint32_t r3 = 1; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + // ne20 = n_used_experts + // ne21 = n_rows (batch size) + const int ne21_mm_id_min = 32; + + if (props_dev->has_simdgroup_mm && + ne00 % 32 == 0 && ne00 >= 64 && + (ne21 >= ne21_mm_id_min)) { + GGML_ASSERT(ne00 % 4 == 0); + + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + switch (op->src[0]->type) { + case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; + default: break; + } + + // extra buffers for intermediate id mapping + ggml_metal_buffer_id bid_tpe = bid_dst; + bid_tpe.offs += ggml_nbytes(op); + + ggml_metal_buffer_id bid_ids = bid_tpe; + bid_ids.offs += ggml_metal_op_mul_mat_id_extra_tpe(op); + + { + ggml_metal_kargs_mul_mm_id_map0 args = { + ne02, + ne10, + ne11, // n_expert_used (bcast) + nb11, + nb12, + ne21, // n_tokens + ne20, // n_expert_used + nb21, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src2, 1); + ggml_metal_encoder_set_buffer (enc, bid_tpe, 2); + ggml_metal_encoder_set_buffer (enc, bid_ids, 3); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, ne02, 1, 1); + } + + // this barrier is always needed because the next kernel has to wait for the id maps to be computed + ggml_metal_op_concurrency_reset(ctx); + + { + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op->src[0]->type, GGML_TYPE_F16); + + ggml_metal_kargs_mul_mm_id args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, // n_expert_used (bcast) + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne20 =*/ ne20, // n_expert_used + /*.ne21 =*/ ne21, // n_tokens + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_tpe, 3); + ggml_metal_encoder_set_buffer (enc, bid_ids, 4); + ggml_metal_encoder_set_buffer (enc, bid_dst, 5); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1); + } + } else { + ggml_metal_kargs_mul_mv_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nb1 =*/ nb1, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); + + const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); + const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); + const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + if (ggml_is_quantized(op->src[0]->type)) { + GGML_ASSERT(ne00 >= nsg*nr0); + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, bid_src0, 1); + ggml_metal_encoder_set_buffer(enc, bid_src1, 2); + ggml_metal_encoder_set_buffer(enc, bid_dst, 3); + ggml_metal_encoder_set_buffer(enc, bid_src2, 4); + + const int64_t _ne1 = 1; + const int64_t ne123 = ne20*ne21; + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + if (op->src[0]->type == GGML_TYPE_F32 || + op->src[0]->type == GGML_TYPE_F16 || + op->src[0]->type == GGML_TYPE_BF16 || + op->src[0]->type == GGML_TYPE_Q8_0) { + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1); + } else { + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1); + } + } + + return 1; +} + +int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_kargs_add_id args = { + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb11 =*/ nb11, + /*.nb21 =*/ nb21, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, 1, nth, 1, 1); + + return 1; +} + +bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + const int64_t ne00 = op->src[0]->ne[0]; // head size + const int64_t ne01 = op->src[0]->ne[1]; // batch size + + // use vec kernel if the batch size is small and if the head size is supported + return (ne01 < 20) && (ne00 % 32 == 0); +} + +size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + const int64_t nwg = 32; + + const int64_t ne01 = op->src[0]->ne[1]; + const int64_t ne02 = op->src[0]->ne[2]; + const int64_t ne03 = op->src[0]->ne[3]; + const int64_t ne20 = op->src[2]->ne[0]; + + // temp buffer for writing the results from each workgroup + // - ne20: the size of the Value head + // - + 2: the S and M values for each intermediate result + return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2)); +} + +int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS( int32_t, nb, op, nb); + + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ne11 % 32 == 0); + + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == op->src[2]->type); + + //GGML_ASSERT(ggml_are_same_shape (src1, src2)); + GGML_ASSERT(ne11 == ne21); + GGML_ASSERT(ne12 == ne22); + + GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16); + GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + + float scale; + float max_bias; + float logit_softcap; + + memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias)); + memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const bool has_mask = op->src[3] != NULL; + const bool has_sinks = op->src[4] != NULL; + const bool has_bias = max_bias != 0.0f; + const bool has_scap = logit_softcap != 0.0f; + + const uint32_t n_head = op->src[0]->ne[2]; + const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + GGML_ASSERT(ne01 < 65536); + + if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { + // half8x8 kernel + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0; + + // 2*(2*ncpsg) + // ncpsg soft_max values + ncpsg mask values + // + // 16*32*(nsg) + // the shared memory needed for the simdgroups to load the KV cache + // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) + + //int64_t nsgmax = 4; + // + //if (is_q) { + // nsgmax = 2; + // while (true) { + // const size_t smem = FATTN_SMEM(nsgmax); + // if (smem > props_dev->max_theadgroup_memory_size) { + // break; + // } + // nsgmax *= 2; + // } + // nsgmax /= 2; + //} + + // simdgroups per threadgroup (a.k.a. warps) + //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + int32_t nsg = 4; + + const size_t smem = FATTN_SMEM(nsg); + + ggml_metal_kargs_flash_attn_ext args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.ns10 =*/ int32_t(nb11/nb10), + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ns20 =*/ int32_t(nb21/nb20), + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); + if (op->src[3]) { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4); + } else { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4); + } + if (op->src[4]) { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5); + } else { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5); + } + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1); +#undef FATTN_SMEM + } else { + // half4x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + const int64_t nkpsg = 1*ncpsg; + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // ne00 + 2*ncpsg*(nsg) + // for each query, we load it as f16 in shared memory (ne00) + // and store the soft_max values and the mask + // + // ne20*(nsg) + // each simdgroup has a full f32 head vector in shared mem to accumulate results + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16)) + + int64_t nsgmax = 2; + while (true) { + const size_t smem = FATTN_SMEM(nsgmax); + // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes + if (smem > props_dev->max_theadgroup_memory_size/2) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); + const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32))); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + // workgroups + // each workgroup handles nsg*nkpsg cache values + int32_t nwg = 1; + if (false) { + // for small KV caches, we could launch a single workgroup and write the results directly to dst/ + // however, this does not lead to significant improvement, so disabled + nwg = 1; + nsg = 4; + } else { + nwg = 32; + nsg = 1; + while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) { + nsg *= 2; + } + } + + ggml_metal_kargs_flash_attn_ext_vec args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.ns10 =*/ int32_t(nb11/nb10), + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ns20 =*/ int32_t(nb21/nb20), + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg); + + GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); + if (op->src[3]) { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4); + } else { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4); + } + if (op->src[4]) { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5); + } else { + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5); + } + + const size_t smem = FATTN_SMEM(nsg); + + //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax); + GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size); + + if (nwg == 1) { + // using 1 workgroup -> write the result directly into dst + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); + } else { + // sanity checks + GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); + GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31)); + + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + // write the results from each workgroup into a temp buffer + ggml_metal_buffer_id bid_tmp = bid_dst; + bid_tmp.offs += ggml_nbytes(op); + ggml_metal_encoder_set_buffer(enc, bid_tmp, 6); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); + + // sync the 2 kernels + ggml_metal_op_concurrency_reset(ctx); + + // reduce the results from the workgroups + { + const int32_t nrows = ne1*ne2*ne3; + + ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = { + nrows, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_tmp, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, 32*nwg, 1, 1); + } + } +#undef FATTN_SMEM + } + + return 1; +} + +int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_tensor ** ops = ggml_graph_nodes(gf) + idx; + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const int idx_end = ctx->idx_end; + + const bool use_fusion = ctx->use_fusion; + + const int debug_fusion = ctx->debug_fusion; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); + + bool bcast_row = false; + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.offs =*/ 0, + /*.o1 =*/ { bid_src1.offs }, + }; + + ggml_op fops[8]; + + int n_fuse = 1; + + // c[0] = add(a, b[0]) + // c[1] = add(c[0], b[1]) + // c[2] = add(c[1], b[2]) + // ... + if (use_fusion) { + fops[0] = GGML_OP_ADD; + fops[1] = GGML_OP_ADD; + fops[2] = GGML_OP_ADD; + fops[3] = GGML_OP_ADD; + fops[4] = GGML_OP_ADD; + fops[5] = GGML_OP_ADD; + fops[6] = GGML_OP_ADD; + fops[7] = GGML_OP_ADD; + + // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops + // across splits. idx_end indicates the last node in the current split + for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) { + if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) { + break; + } + + if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) { + break; + } + + // b[0] === b[1] === ... + if (!ggml_are_same_layout(ops[n_fuse]->src[1], ops[n_fuse + 1]->src[1])) { + break; + } + + // only fuse ops if src1 is in the same Metal buffer + ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]); + if (bid_fuse.metal != bid_src1.metal) { + break; + } + + //ctx->fuse_cnt[ops[n_fuse + 1]->op]++; + + args.o1[n_fuse + 1] = bid_fuse.offs; + } + + ++n_fuse; + + if (debug_fusion > 1 && n_fuse > 1) { + GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse); + } + } + + // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer + bid_src1.offs = 0; + + ggml_metal_pipeline_t pipeline = nullptr; + + if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true); + + bcast_row = true; + } else { + pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false); + } + + if (n_fuse > 1) { + bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]); + + for (int i = 1; i < n_fuse; ++i) { + if (!ggml_metal_op_concurrency_check(ctx, ops[i])) { + ggml_metal_op_concurrency_reset(ctx); + + break; + } + } + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_dst, 3); + + if (bcast_row) { + const int64_t n = ggml_nelements(op)/4; + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + } else { + int nth = 32; + + while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + } + + return n_fuse; +} + +int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const int idx_end = ctx->idx_end; + + const bool use_fusion = ctx->use_fusion; + + const int debug_fusion = ctx->debug_fusion; + + ggml_tensor ** ops = ggml_graph_nodes(gf) + idx; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float eps; + memcpy(&eps, op->op_params, sizeof(float)); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_kargs_rms_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.eps =*/ eps, + /*.nef1 =*/ { ne01 }, + /*.nef2 =*/ { ne02 }, + /*.nef3 =*/ { ne03 }, + /*.nbf1 =*/ { nb01 }, + /*.nbf2 =*/ { nb02 }, + /*.nbf3 =*/ { nb03 }, + }; + + ggml_op fops[8]; + + int n_fuse = 1; + + ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 }; + + // d[0] = rms_norm(a) + // d[1] = mul(d[0], b) + // d[2] = add(d[1], c) + if (use_fusion) { + fops[0] = GGML_OP_RMS_NORM; + fops[1] = GGML_OP_MUL; + fops[2] = GGML_OP_ADD; + + for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) { + if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) { + break; + } + + if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) { + break; + } + + if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) { + break; + } + + if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) { + break; + } + + if (ops[n_fuse + 1]->type != GGML_TYPE_F32) { + break; + } + + //ctx->fuse_cnt[ops[n_fuse + 1]->op]++; + + bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]); + + args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1]; + args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2]; + args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3]; + + args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1]; + args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2]; + args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3]; + } + + ++n_fuse; + + if (debug_fusion > 1 && n_fuse > 1) { + if (n_fuse == 2) { + GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__); + } + if (n_fuse == 3) { + GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__); + } + } + } + + if (n_fuse > 1) { + bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]); + + for (int i = 1; i < n_fuse; ++i) { + if (!ggml_metal_op_concurrency_check(ctx, ops[i])) { + ggml_metal_op_concurrency_reset(ctx); + + break; + } + } + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rms_norm(lib, op, n_fuse); + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00/4); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2); + ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3); + ggml_metal_encoder_set_buffer (enc, bid_dst, 4); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return n_fuse; +} + +int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float eps; + memcpy(&eps, op->op_params, sizeof(float)); + + int nth = 32; // SIMD width + + ggml_metal_kargs_l2_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); + + while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00/4); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + const int64_t nrows = ggml_nrows(op->src[0]); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int32_t ngrp = ((const int32_t *) op->op_params)[0]; + + float eps; + memcpy(&eps, op->op_params + 1, sizeof(float)); + + ggml_metal_kargs_group_norm args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ngrp =*/ ngrp, + /*.eps =*/ eps, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op); + + int nth = 32; // SIMD width + //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + // nth *= 2; + //} + + //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + //nth = std::min(nth, ne00/4); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ngrp, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const int idx_end = ctx->idx_end; + + const bool use_fusion = ctx->use_fusion; + + const int debug_fusion = ctx->debug_fusion; + + ggml_tensor ** ops = ggml_graph_nodes(gf) + idx; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float eps; + memcpy(&eps, op->op_params, sizeof(float)); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_kargs_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.eps =*/ eps, + /*.nef1 =*/ { ne01 }, + /*.nef2 =*/ { ne02 }, + /*.nef3 =*/ { ne03 }, + /*.nbf1 =*/ { nb01 }, + /*.nbf2 =*/ { nb02 }, + /*.nbf3 =*/ { nb03 }, + }; + + ggml_op fops[8]; + + int n_fuse = 1; + + ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 }; + + // d[0] = norm(a) + // d[1] = mul(d[0], b) + // d[2] = add(d[1], c) + if (use_fusion) { + fops[0] = GGML_OP_NORM; + fops[1] = GGML_OP_MUL; + fops[2] = GGML_OP_ADD; + + for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) { + if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) { + break; + } + + if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) { + break; + } + + if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) { + break; + } + + if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) { + break; + } + + if (ops[n_fuse + 1]->type != GGML_TYPE_F32) { + break; + } + + //ctx->fuse_cnt[ops[n_fuse + 1]->op]++; + + bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]); + + args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1]; + args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2]; + args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3]; + + args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1]; + args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2]; + args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3]; + } + + ++n_fuse; + + if (debug_fusion > 1 && n_fuse > 1) { + if (n_fuse == 2) { + GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__); + } + if (n_fuse == 3) { + GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__); + } + } + } + + if (n_fuse > 1) { + bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]); + + for (int i = 1; i < n_fuse; ++i) { + if (!ggml_metal_op_concurrency_check(ctx, ops[i])) { + ggml_metal_op_concurrency_reset(ctx); + + break; + } + } + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse); + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nth *= 2; + } + + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + nth = std::min(nth, ne00/4); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2); + ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3); + ggml_metal_encoder_set_buffer (enc, bid_dst, 4); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return n_fuse; +} + +int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + // make sure we have one or more position id(ne10) per token(ne02) + GGML_ASSERT(ne10 % ne02 == 0); + GGML_ASSERT(ne10 >= ne02); + + const int nth = std::min(1024, ne00); + + const int n_past = ((const int32_t *) op->op_params)[0]; + const int n_dims = ((const int32_t *) op->op_params)[1]; + //const int mode = ((const int32_t *) op->op_params)[2]; + // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal + const int n_ctx_orig = ((const int32_t *) op->op_params)[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (const int32_t *) op->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (const int32_t *) op->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (const int32_t *) op->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (const int32_t *) op->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (const int32_t *) op->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (const int32_t *) op->op_params + 10, sizeof(float)); + + // mrope + const int sect_0 = ((const int32_t *) op->op_params)[11]; + const int sect_1 = ((const int32_t *) op->op_params)[12]; + const int sect_2 = ((const int32_t *) op->op_params)[13]; + const int sect_3 = ((const int32_t *) op->op_params)[14]; + + ggml_metal_kargs_rope args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.n_past =*/ n_past, + /*.n_dims =*/ n_dims, + /*.n_ctx_orig =*/ n_ctx_orig, + /*.freq_base =*/ freq_base, + /*.freq_scale =*/ freq_scale, + /*.ext_factor =*/ ext_factor, + /*.attn_factor =*/ attn_factor, + /*.beta_fast =*/ beta_fast, + /*.beta_slow =*/ beta_slow, + /* sect_0 =*/ sect_0, + /* sect_1 =*/ sect_1, + /* sect_2 =*/ sect_2, + /* sect_3 =*/ sect_3, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + if (op->src[2]) { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); + } else { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 3); + } + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int32_t s0 = ((const int32_t *)(op->op_params))[0]; + const int32_t s1 = ((const int32_t *)(op->op_params))[1]; + const int32_t p0 = ((const int32_t *)(op->op_params))[2]; + const int32_t p1 = ((const int32_t *)(op->op_params))[3]; + const int32_t d0 = ((const int32_t *)(op->op_params))[4]; + const int32_t d1 = ((const int32_t *)(op->op_params))[5]; + + const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1; + + const int32_t N = op->src[1]->ne[is_2D ? 3 : 2]; + const int32_t IC = op->src[1]->ne[is_2D ? 2 : 1]; + const int32_t IH = is_2D ? op->src[1]->ne[1] : 1; + const int32_t IW = op->src[1]->ne[0]; + + const int32_t KH = is_2D ? op->src[0]->ne[1] : 1; + const int32_t KW = op->src[0]->ne[0]; + + const int32_t OH = is_2D ? op->ne[2] : 1; + const int32_t OW = op->ne[1]; + + const int32_t CHW = IC * KH * KW; + + const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4; + const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4; + + + ggml_metal_kargs_im2col args = { + /*.ofs0 =*/ ofs0, + /*.ofs1 =*/ ofs1, + /*.IW =*/ IW, + /*.IH =*/ IH, + /*.CHW =*/ CHW, + /*.s0 =*/ s0, + /*.s1 =*/ s1, + /*.p0 =*/ p0, + /*.p1 =*/ p1, + /*.d0 =*/ d0, + /*.d1 =*/ d1, + /*.N =*/ N, + /*.KH =*/ KH, + /*.KW =*/ KW, + /*.KHW =*/ KH * KW, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); + + GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW); + + return 1; +} + +int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int32_t s0 = ((const int32_t *)(op->op_params))[0]; + + const int32_t IC = op->src[1]->ne[1]; + const int32_t IL = op->src[1]->ne[0]; + + const int32_t K = op->src[0]->ne[0]; + + const int32_t OL = op->ne[0]; + const int32_t OC = op->ne[1]; + + ggml_metal_kargs_conv_transpose_1d args = { + /*.IC =*/ IC, + /*.IL =*/ IL, + /*.K =*/ K, + /*.s0 =*/ s0, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, OL, OC, 1, 1, 1, 1); + + return 1; +} + +int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const float sf0 = (float)ne0/op->src[0]->ne[0]; + const float sf1 = (float)ne1/op->src[0]->ne[1]; + const float sf2 = (float)ne2/op->src[0]->ne[2]; + const float sf3 = (float)ne3/op->src[0]->ne[3]; + + ggml_metal_kargs_upscale args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.sf0 =*/ sf0, + /*.sf1 =*/ sf1, + /*.sf2 =*/ sf2, + /*.sf3 =*/ sf3 + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_kargs_pad args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3 + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad(lib, op); + + const int nth = std::min(1024, ne0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_kargs_pad_reflect_1d args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.p0 =*/ ((const int32_t *)(op->op_params))[0], + /*.p1 =*/ ((const int32_t *)(op->op_params))[1] + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op); + + const int nth = std::min(1024, ne0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float start; + float step; + + memcpy(&start, ((const int32_t *) op->op_params) + 0, sizeof(float)); + memcpy(&step, ((const int32_t *) op->op_params) + 2, sizeof(float)); + + ggml_metal_kargs_arange args = { + /*.ne0 =*/ ne0, + /*.start =*/ start, + /*.step =*/ step + }; + + const int nth = std::min(1024, ne0); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op); + + //[encoder setComputePipelineState:pipeline]; + //[encoder setBuffer:id_dst offset:offs_dst atIndex:0]; + //[encoder setBytes:&args length:sizeof(args) atIndex:1]; + + //[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int dim = op->op_params[0]; + const int max_period = op->op_params[1]; + + ggml_metal_kargs_timestep_embedding args = { + /*.nb1 =*/ nb1, + /*.dim =*/ dim, + /*.max_period =*/ max_period, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op); + + const int nth = std::max(1, std::min(1024, dim/2)); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne00, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_kargs_argmax args = { + /*.ne00 = */ ne00, + /*.nb01 = */ nb01, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argmax(lib, op); + + const int64_t nrows = ggml_nrows(op->src[0]); + + int nth = 32; // SIMD width + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + // bitonic sort requires the number of elements to be power of 2 + int64_t ne00_padded = 1; + while (ne00_padded < ne00) { + ne00_padded *= 2; + } + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op); + + const int64_t nrows = ggml_nrows(op->src[0]); + + // Metal kernels require the buffer size to be multiple of 16 bytes + // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength + const size_t smem = GGML_PAD(ne00_padded*sizeof(int32_t), 16); + + ggml_metal_kargs_argsort args = { + /*.ncols =*/ ne00, + /*.ncols_pad =*/ ne00_padded + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, nrows, 1, ne00_padded, 1, 1); + + return 1; +} + +int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { + ggml_cgraph * gf = ctx->gf; + ggml_tensor * op = ggml_graph_node(gf, idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + float slope; + memcpy(&slope, op->op_params, sizeof(float)); + + ggml_metal_kargs_leaky_relu args = { + /*.slope =*/ slope + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + + int64_t n = ggml_nelements(op); + + if (n % 4 == 0) { + n /= 4; + } + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + + return 1; +} diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h new file mode 100644 index 00000000000..b620de164d7 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -0,0 +1,81 @@ +#pragma once + +#include "ggml-metal-device.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct ggml_metal_op * ggml_metal_op_t; + +ggml_metal_op_t ggml_metal_op_init( + ggml_metal_device_t dev, + ggml_metal_cmd_buf_t cmd_buf, + struct ggml_cgraph * gf, + int idx_start, + int idx_end, + bool use_fusion, + bool use_concurrency, + bool use_capture, + int debug_graph, + int debug_fusion); + +void ggml_metal_op_free(ggml_metal_op_t ctx); + +int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx); + +// +// available ops: +// + +// tokens per expert +size_t ggml_metal_op_mul_mat_id_extra_tpe(const struct ggml_tensor * op); + +// id map [n_tokens, n_expert] +size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op); + +// return true if we should use the FA vector kernel for this op +bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op); + +size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op); + +int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_mul_mat (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_rms_norm (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); +int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp new file mode 100644 index 00000000000..e11555a78fc --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -0,0 +1,718 @@ +#include "ggml-metal.h" + +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include "ggml-metal-device.h" +#include "ggml-metal-context.h" +#include "ggml-metal-ops.h" + +// globals + +// initialized in ggml_backend_metal_reg +static ggml_backend_reg g_ggml_metal_reg; +static ggml_backend_device g_ggml_metal_device; + +//////////////////////////////////////////////////////////////////////////////// +// backend interface +//////////////////////////////////////////////////////////////////////////////// + +// shared buffer + +static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t buffer) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_free(ctx); +} + +static void * ggml_backend_metal_buffer_shared_get_base(ggml_backend_buffer_t buffer) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); + + return ggml_metal_buffer_get_base(ctx); +} + +static void ggml_backend_metal_buffer_shared_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size); +} + +static void ggml_backend_metal_buffer_shared_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size); +} + +static void ggml_backend_metal_buffer_shared_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size); +} + +static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); + + GGML_UNUSED(buffer); + GGML_UNUSED(src); + GGML_UNUSED(dst); + + return false; +} + +static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_clear(ctx, value); +} + +static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = { + /* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_shared_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, + /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_shared_clear, + /* .reset = */ NULL, +}; + +// private buffer + +static void ggml_backend_metal_buffer_private_free_buffer(ggml_backend_buffer_t buffer) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_free(ctx); +} + +static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + return ggml_metal_buffer_get_base(ctx); +} + +static void ggml_backend_metal_buffer_private_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size); +} + +static void ggml_backend_metal_buffer_private_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size); +} + +static void ggml_backend_metal_buffer_private_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size); +} + +static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + GGML_UNUSED(buffer); + GGML_UNUSED(src); + GGML_UNUSED(dst); + + return false; +} + +static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; + + GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); + + ggml_metal_buffer_clear(ctx, value); +} + +static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { + /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_private_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, + /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_private_clear, + /* .reset = */ NULL, +}; + +// +// buffer types +// + +// common method for allocating shread or private Metal buffers +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; + ggml_metal_buffer_t res = ggml_metal_buffer_init(ctx_dev, size, shared); + + ggml_backend_buffer_i buf_i = ggml_metal_buffer_is_shared(res) + ? ggml_backend_metal_buffer_shared_i + : ggml_backend_metal_buffer_private_i; + + return ggml_backend_buffer_init(buft, buf_i, res, size); +} + +static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + size_t res = ggml_nbytes(tensor); + + // some operations require additional memory for fleeting data: + switch (tensor->op) { + case GGML_OP_MUL_MAT_ID: + { + res += ggml_metal_op_mul_mat_id_extra_tpe(tensor); + res += ggml_metal_op_mul_mat_id_extra_ids(tensor); + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) { + res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor); + } + } break; + default: + break; + } + + return res; + + GGML_UNUSED(buft); +} + +// default (shared) buffer type + +static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) { + return "Metal"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true); +} + +static size_t ggml_backend_metal_buffer_type_shared_get_alignment(ggml_backend_buffer_type_t buft) { + return 32; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; + + return ggml_metal_device_get_props(ctx_dev)->max_buffer_size; +} + +static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor); +} + +static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) { + return false; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) { + static ggml_backend_buffer_type ggml_backend_buffer_type_metal = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, + }, + /* .device = */ &g_ggml_metal_device, + /* .context = */ NULL, + }; + + return &ggml_backend_buffer_type_metal; +} + +// default (private) buffer type + +static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) { + return "Metal_Private"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, false); +} + +static size_t ggml_backend_metal_buffer_type_private_get_alignment(ggml_backend_buffer_type_t buft) { + return 32; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; + + return ggml_metal_device_get_props(ctx_dev)->max_buffer_size; +} + +static size_t ggml_backend_metal_buffer_type_private_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor); +} + +static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) { + return false; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) { + static ggml_backend_buffer_type ggml_backend_buffer_type_metal = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, + }, + /* .device = */ &g_ggml_metal_device, + /* .context = */ NULL, + }; + + return &ggml_backend_buffer_type_metal; +} + +// mapped buffer type + +static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) { + return "Metal_Mapped"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + // for mapped buffers, prefer shared memory + return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true); +} + +static size_t ggml_backend_metal_buffer_type_mapped_get_alignment(ggml_backend_buffer_type_t buft) { + return 32; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; + + return ggml_metal_device_get_props(ctx_dev)->max_buffer_size; +} + +static size_t ggml_backend_metal_buffer_type_mapped_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor); +} + +static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) { + return false; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) { + // note: not obvious, but this buffer type still needs to implement .alloc_buffer: + // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099 + static ggml_backend_buffer_type ggml_backend_buffer_type_mapped_metal = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, + }, + /* .device = */ &g_ggml_metal_device, + /* .context = */ NULL, + }; + + return &ggml_backend_buffer_type_mapped_metal; +} + +// backend + +static const char * ggml_backend_metal_name(ggml_backend_t backend) { + return "Metal"; + + GGML_UNUSED(backend); +} + +static void ggml_backend_metal_free(ggml_backend_t backend) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + // wait for any ongoing async operations to finish + ggml_metal_synchronize(ctx); + + ggml_metal_free(ctx); + + free(backend); +} + +static void ggml_backend_metal_synchronize(ggml_backend_t backend) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + ggml_metal_synchronize(ctx); +} + +static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + ggml_metal_set_tensor_async(ctx, tensor, data, offset, size); +} + +static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + ggml_metal_get_tensor_async(ctx, tensor, data, offset, size); +} + +static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { + return false; + + GGML_UNUSED(backend_src); + GGML_UNUSED(backend_dst); + GGML_UNUSED(src); + GGML_UNUSED(dst); +} + +static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + return ggml_metal_graph_compute(ctx, cgraph); +} + +static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + ggml_metal_graph_optimize(ctx, cgraph); +} + +static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + ggml_metal_set_n_cb(ctx, n_cb); + +} + +static ggml_backend_i ggml_backend_metal_i = { + /* .get_name = */ ggml_backend_metal_name, + /* .free = */ ggml_backend_metal_free, + /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async, + /* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups + /* .synchronize = */ ggml_backend_metal_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_metal_graph_compute, + + // the events API is needed only for multi-GPU setups, so likely no need to implement it for Metal + // in any case, these docs seem relevant if we ever decide to implement it: + // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ ggml_backend_metal_graph_optimize, +}; + +static ggml_guid_t ggml_backend_metal_guid(void) { + static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 }; + return &guid; +} + +ggml_backend_t ggml_backend_metal_init(void) { + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0); + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_t ctx = ggml_metal_init(ctx_dev); + if (ctx == NULL) { + GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); + return NULL; + } + + ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend)); + + *backend = { + /* .guid = */ ggml_backend_metal_guid(), + /* .interface = */ ggml_backend_metal_i, + /* .device = */ dev, + /* .context = */ ctx, + }; + + ggml_backend_metal_set_n_cb(backend, 1); + + return backend; +} + +bool ggml_backend_is_metal(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid()); +} + +void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + ggml_metal_set_abort_callback(ctx, abort_callback, user_data); +} + +bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + return ggml_metal_supports_family(ctx, family); +} + +void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + ggml_metal_t ctx = (ggml_metal_t)backend->context; + + ggml_metal_capture_next_compute(ctx); +} + +// backend device + +static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) { + return "Metal"; + + GGML_UNUSED(dev); +} + +static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + return ggml_metal_device_get_props(ctx_dev)->name; +} + +static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_device_get_memory(ctx_dev, free, total); +} + +static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_GPU; + + GGML_UNUSED(dev); +} + +static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_metal_device_get_name(dev); + props->description = ggml_backend_metal_device_get_description(dev); + props->type = ggml_backend_metal_device_get_type(dev); + + ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); + + props->caps = { + /* .async = */ true, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_t ctx = ggml_metal_init(ctx_dev); + if (ctx == NULL) { + GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); + return NULL; + } + + ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend)); + + *backend = { + /* .guid = */ ggml_backend_metal_guid(), + /* .interface = */ ggml_backend_metal_i, + /* .device = */ dev, + /* .context = */ ctx, + }; + + ggml_backend_metal_set_n_cb(backend, 1); + + return backend; + + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); + + return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared() : ggml_backend_metal_buffer_type_private(); +} + +static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_buffer_t res = ggml_metal_buffer_map(ctx_dev, ptr, size, max_tensor_size); + + return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(), ggml_backend_metal_buffer_shared_i, res, size); +} + +static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + return ggml_metal_device_supports_op(ctx_dev, op); +} + +static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return + buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name || + buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name || + buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name; + + GGML_UNUSED(dev); +} + +static int64_t get_op_batch_size(const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_MUL_MAT: + return op->ne[1]; + case GGML_OP_MUL_MAT_ID: + return op->ne[2]; + default: + return ggml_nrows(op); + } +} + +static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + const int min_batch_size = 32; + + return (op->op == GGML_OP_MUL_MAT || + op->op == GGML_OP_MUL_MAT_ID) && + get_op_batch_size(op) >= min_batch_size; + + GGML_UNUSED(dev); + GGML_UNUSED(op); +} + +static ggml_backend_device_i ggml_backend_metal_device_i = { + /* .get_name = */ ggml_backend_metal_device_get_name, + /* .get_description = */ ggml_backend_metal_device_get_description, + /* .get_memory = */ ggml_backend_metal_device_get_memory, + /* .get_type = */ ggml_backend_metal_device_get_type, + /* .get_props = */ ggml_backend_metal_device_get_props, + /* .init_backend = */ ggml_backend_metal_device_init, + /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped, + /* .supports_op = */ ggml_backend_metal_device_supports_op, + /* .supports_buft = */ ggml_backend_metal_device_supports_buft, + /* .offload_op = */ ggml_backend_metal_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// backend registry + +static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) { + return "Metal"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) { + return 1; + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index == 0); + + return &g_ggml_metal_device; + + GGML_UNUSED(reg); + GGML_UNUSED(index); +} + +static ggml_backend_feature g_ggml_backend_metal_features[] = { +#if defined(GGML_METAL_EMBED_LIBRARY) + { "EMBED_LIBRARY", "1" }, +#endif + { NULL, NULL }, +}; + +static ggml_backend_feature * ggml_backend_metal_get_features(ggml_backend_reg_t reg) { + return g_ggml_backend_metal_features; + + GGML_UNUSED(reg); +} + +static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (strcmp(name, "ggml_backend_get_features") == 0) { + return (void *)ggml_backend_metal_get_features; + } + + return NULL; + + GGML_UNUSED(reg); +} + +static ggml_backend_reg_i ggml_backend_metal_reg_i = { + /* .get_name = */ ggml_backend_metal_reg_get_name, + /* .device_count = */ ggml_backend_metal_reg_device_count, + /* .device_get = */ ggml_backend_metal_reg_device_get, + /* .get_proc_address = */ ggml_backend_metal_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_metal_reg(void) { + { + g_ggml_metal_reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_metal_reg_i, + /* .context = */ NULL, + }; + + g_ggml_metal_device = { + /* .iface = */ ggml_backend_metal_device_i, + /* .reg = */ &g_ggml_metal_reg, + /* .context = */ ggml_metal_device_get(), + }; + } + + return &g_ggml_metal_reg; +} + +GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m deleted file mode 100644 index cb8eff4a772..00000000000 --- a/ggml/src/ggml-metal/ggml-metal.m +++ /dev/null @@ -1,6775 +0,0 @@ -#import "ggml-metal.h" - -#import "ggml-impl.h" -#import "ggml-backend-impl.h" -#import "ggml-metal-impl.h" - -#import - -#import - -#undef MIN -#undef MAX -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - -// max memory buffers that can be mapped to the device -#define GGML_METAL_MAX_BUFFERS 64 - -// max number of MTLCommandBuffer used to submit a graph for processing -#define GGML_METAL_MAX_COMMAND_BUFFERS 8 - -#ifndef TARGET_OS_VISION -#define TARGET_OS_VISION 0 -#endif - -// create residency sets only on macOS >= 15.0 -#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \ - TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \ - TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \ - TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000 -#define GGML_METAL_HAS_RESIDENCY_SETS 1 -#endif - -// globals - -// overload of MTLGPUFamilyMetal3 (not available in some environments) -static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; - -// initialized in ggml_backend_metal_reg -static struct ggml_backend_reg g_ggml_backend_metal_reg; -static struct ggml_backend_device g_ggml_backend_metal_device; - -// information about a Metal device -// note: assumes single GPU device - the default one -// TODO: support multiple GPU devices -static struct ggml_backend_metal_device_context { - id mtl_device; - int mtl_device_ref_count; - id mtl_library; - - NSLock * mtl_lock; - - bool has_simdgroup_reduction; - bool has_simdgroup_mm; - bool has_residency_sets; - bool has_bfloat; - bool use_bfloat; - bool use_fusion; - - int debug_fusion; - - // how many times a given op was fused - uint64_t fuse_cnt[GGML_OP_COUNT]; - - size_t max_size; - - char name[128]; -} g_ggml_ctx_dev_main = { - /*.mtl_device =*/ nil, - /*.mtl_device_ref_count =*/ 0, - /*.mtl_library =*/ nil, - /*.mtl_lock =*/ nil, - /*.has_simdgroup_reduction =*/ false, - /*.has_simdgroup_mm =*/ false, - /*.has_residency_sets =*/ false, - /*.has_bfloat =*/ false, - /*.use_bfloat =*/ false, - /*.use_fusion =*/ true, - /*.debug_fusion =*/ 0, - /*.fuse_cnt =*/ { 0 }, - /*.max_size =*/ 0, - /*.name =*/ "", -}; - -// acquire -static id ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) { - assert(ctx != NULL); - - if (ctx->mtl_lock == nil) { - ctx->mtl_lock = [[NSLock alloc] init]; - } - - if (ctx->mtl_device == nil) { - ctx->mtl_device = MTLCreateSystemDefaultDevice(); - - ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; - ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; - - ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; - -#if defined(GGML_METAL_HAS_RESIDENCY_SETS) - ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil; -#endif - - ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; - ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6]; - -#if defined(GGML_METAL_USE_BF16) - ctx->use_bfloat = ctx->has_bfloat; -#else - ctx->use_bfloat = false; -#endif - ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil; - - { - const char * val = getenv("GGML_METAL_FUSION_DEBUG"); - ctx->debug_fusion = val ? atoi(val) : 0; - } - - memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt)); - - ctx->max_size = ctx->mtl_device.maxBufferLength; - - strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1); - } - - ctx->mtl_device_ref_count++; - - return ctx->mtl_device; -} - -// release -static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) { - assert(ctx != NULL); - assert(ctx->mtl_device_ref_count > 0); - - ctx->mtl_device_ref_count--; - - if (ctx->mtl_device_ref_count == 0) { - if (ctx->debug_fusion > 0) { - fprintf(stderr, "%s: fusion stats:\n", __func__); - for (int i = 0; i < GGML_OP_COUNT; i++) { - if (ctx->fuse_cnt[i] == 0) { - continue; - } - - // note: cannot use ggml_log here - fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]); - } - } - - if (ctx->mtl_lock) { - [ctx->mtl_lock release]; - ctx->mtl_lock = nil; - } - - if (ctx->mtl_library) { - [ctx->mtl_library release]; - ctx->mtl_library = nil; - } - - if (ctx->mtl_device) { - [ctx->mtl_device release]; - ctx->mtl_device = nil; - } - } -} - -// kernels - -struct ggml_metal_kernel { - id pipeline; -}; - -enum ggml_metal_kernel_type { - GGML_METAL_KERNEL_TYPE_ADD, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, - GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, - GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, - GGML_METAL_KERNEL_TYPE_SUB, - GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, - GGML_METAL_KERNEL_TYPE_MUL, - GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, - GGML_METAL_KERNEL_TYPE_DIV, - GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, - GGML_METAL_KERNEL_TYPE_ADD_ID, - GGML_METAL_KERNEL_TYPE_REPEAT_F32, - GGML_METAL_KERNEL_TYPE_REPEAT_F16, - GGML_METAL_KERNEL_TYPE_REPEAT_I32, - GGML_METAL_KERNEL_TYPE_REPEAT_I16, - GGML_METAL_KERNEL_TYPE_SCALE, - GGML_METAL_KERNEL_TYPE_SCALE_4, - GGML_METAL_KERNEL_TYPE_CLAMP, - GGML_METAL_KERNEL_TYPE_TANH, - GGML_METAL_KERNEL_TYPE_RELU, - GGML_METAL_KERNEL_TYPE_SIGMOID, - GGML_METAL_KERNEL_TYPE_GELU, - GGML_METAL_KERNEL_TYPE_GELU_4, - GGML_METAL_KERNEL_TYPE_GELU_ERF, - GGML_METAL_KERNEL_TYPE_GELU_ERF_4, - GGML_METAL_KERNEL_TYPE_GELU_QUICK, - GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, - GGML_METAL_KERNEL_TYPE_SILU, - GGML_METAL_KERNEL_TYPE_SILU_4, - GGML_METAL_KERNEL_TYPE_ELU, - GGML_METAL_KERNEL_TYPE_ABS, - GGML_METAL_KERNEL_TYPE_SGN, - GGML_METAL_KERNEL_TYPE_STEP, - GGML_METAL_KERNEL_TYPE_HARDSWISH, - GGML_METAL_KERNEL_TYPE_HARDSIGMOID, - GGML_METAL_KERNEL_TYPE_EXP, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, - GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, - GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, - GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, - GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, - GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, - GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, - GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, - GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, - GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, - GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, - GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, - GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, - GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, - GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, - GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, - GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, - GGML_METAL_KERNEL_TYPE_RMS_NORM, - GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, - GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, - GGML_METAL_KERNEL_TYPE_L2_NORM, - GGML_METAL_KERNEL_TYPE_GROUP_NORM, - GGML_METAL_KERNEL_TYPE_NORM, - GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, - GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, - GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, - GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, - GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, - GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, - GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, - GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, - GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, - GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, - GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, - GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, - GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, - GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, - //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, - //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, - //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, - GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, - GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, - GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, - GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, - GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, - GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, - GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, - GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, - GGML_METAL_KERNEL_TYPE_IM2COL_F16, - GGML_METAL_KERNEL_TYPE_IM2COL_F32, - GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, - GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, - GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, - GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, - GGML_METAL_KERNEL_TYPE_UPSCALE_F32, - GGML_METAL_KERNEL_TYPE_PAD_F32, - GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, - GGML_METAL_KERNEL_TYPE_ARANGE_F32, - GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, - GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, - GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, - GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_SET_I32, - GGML_METAL_KERNEL_TYPE_SET_F32, - GGML_METAL_KERNEL_TYPE_CPY_F32_F32, - GGML_METAL_KERNEL_TYPE_CPY_F32_F16, - GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, - GGML_METAL_KERNEL_TYPE_CPY_F16_F16, - GGML_METAL_KERNEL_TYPE_CPY_F16_F32, - GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, - GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, - GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, - GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, - GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, - GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, - GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, - GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, - GGML_METAL_KERNEL_TYPE_CONCAT, - GGML_METAL_KERNEL_TYPE_SQR, - GGML_METAL_KERNEL_TYPE_SQRT, - GGML_METAL_KERNEL_TYPE_SIN, - GGML_METAL_KERNEL_TYPE_COS, - GGML_METAL_KERNEL_TYPE_NEG, - GGML_METAL_KERNEL_TYPE_REGLU, - GGML_METAL_KERNEL_TYPE_GEGLU, - GGML_METAL_KERNEL_TYPE_SWIGLU, - GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, - GGML_METAL_KERNEL_TYPE_GEGLU_ERF, - GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, - GGML_METAL_KERNEL_TYPE_SUM_ROWS, - GGML_METAL_KERNEL_TYPE_MEAN, - GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, - GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, - GGML_METAL_KERNEL_TYPE_ARGMAX, - - GGML_METAL_KERNEL_TYPE_COUNT -}; - -// -// ggml_metal_heap -// - -struct ggml_metal_heap { - // number of times the heap was unused - int n_unused; - - // total number of buffer allocations in this heap across all computes - int64_t n_alloc; - - // current offset in the heap - we reset this after each node in order to reuse the memory - size_t offs; - - // the currently allocated MTLBuffer objects in this heap - id obj; - - NSMutableArray * bufs; -}; - -static struct ggml_metal_heap * ggml_metal_heap_init(id device, size_t size) { - struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap)); - - MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init]; - desc.storageMode = MTLStorageModePrivate; - desc.cpuCacheMode = MTLCPUCacheModeDefaultCache; - desc.type = MTLHeapTypePlacement; - desc.size = size; - - heap->n_unused = 0; - heap->n_alloc = 0; - - heap->obj = [device newHeapWithDescriptor:desc]; - if (!heap->obj) { - GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size); - - free(heap); - - return false; - } - - [desc release]; - - heap->bufs = [[NSMutableArray alloc] init]; - - return heap; -} - -static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) { - heap->offs = 0; - - // count how many graph computes the heap ended up being unused - if ([heap->bufs count] > 0) { - heap->n_unused = 0; - } else { - heap->n_unused++; - } - - for (id buf in heap->bufs) { - [buf release]; - } - [heap->bufs removeAllObjects]; - - // tell the OS that it can reuse this memory if needed - // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc - [heap->obj setPurgeableState:MTLPurgeableStateVolatile]; -} - -static void ggml_metal_heap_free(struct ggml_metal_heap * heap) { - if (heap == nil) { - return; - } - - ggml_metal_heap_reset(heap); - - [heap->obj release]; - [heap->bufs release]; - - free(heap); -} - -@interface ggml_metal_heap_ptr : NSObject - -@property (nonatomic, assign) struct ggml_metal_heap * data; - -@end - -@implementation ggml_metal_heap_ptr -@end - -// -// ggml_metal_mem_pool -// - -struct ggml_metal_mem_pool { - id device; - - int n_heaps; // total number of heaps ever created (including those that were removed) - - NSMutableArray * heaps; - NSMutableArray * heaps_to_remove; -}; - -static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) { - struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool)); - - mem_pool->n_heaps = 0; - - mem_pool->heaps = [[NSMutableArray alloc] init]; - mem_pool->heaps_to_remove = [[NSMutableArray alloc] init]; - - return mem_pool; -} - -static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) { - GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps); - - size_t size_all = 0; - size_t size_cur = 0; - - for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { - GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data); - GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc); - GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused); - GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0); - GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]); - - if ([ptr.data->bufs count] > 0) { - size_cur += [ptr.data->obj size]; - } - size_all += [ptr.data->obj size]; - - ggml_metal_heap_free(ptr.data); - [ptr release]; - } - [mem_pool->heaps release]; - [mem_pool->heaps_to_remove release]; - - if (size_all > 0) { - GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0); - GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0); - } - - free(mem_pool); -} - -static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) { - for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) { - ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i]; - - struct ggml_metal_heap * heap = ptr.data; - ggml_metal_heap_reset(heap); - - // if the heap hasn't been used for a while, remove it - if (heap->n_unused >= 128) { - [mem_pool->heaps_to_remove addObject:@(i)]; - } - } - - if (mem_pool->heaps_to_remove.count > 0) { - // remove in reverse order - for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) { - NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue]; - ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index]; - - struct ggml_metal_heap * heap = ptr.data; - ggml_metal_heap_free(heap); - - [mem_pool->heaps removeObjectAtIndex:index]; - [ptr release]; - - if (i == 0) { - break; - } - } - - [mem_pool->heaps_to_remove removeAllObjects]; - } -} - -static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) { - for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { - ptr.data->offs = 0; - } -} - -static id ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) { - const size_t alignment = 256; - - const size_t size_aligned = GGML_PAD(size, alignment); - - // try one of the existing heaps - for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { - struct ggml_metal_heap * heap = ptr.data; - if (heap->offs + size_aligned <= [heap->obj size]) { - // if this is the first buffer in the heap for the current command buffer, tell the OS that - // it cannot free the memory used by the heap - // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc - if ([heap->bufs count] == 0) { - [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile]; - } - - id buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs]; - if (buf == nil) { - GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned); - return nil; - } - - heap->n_alloc++; - heap->offs += size_aligned; - - [heap->bufs addObject:buf]; - - return buf; - } - } - - // create a new heap that can fit this buffer - ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new]; - - struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned); - if (heap == NULL) { - GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned); - return NULL; - } - - //GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]); - - heap_ptr.data = heap; - ggml_metal_heap_reset(heap); - - [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile]; - id buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs]; - if (buf == nil) { - GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned); - return NULL; - } - - heap->n_alloc++; - heap->offs += size_aligned; - - [heap->bufs addObject:buf]; - - [mem_pool->heaps addObject:heap_ptr]; - mem_pool->n_heaps++; - - return buf; -} - -struct ggml_metal_command_buffer { - id obj; - - // each command buffer has a memory pool from which it can allocate temporary buffers during the compute - struct ggml_metal_mem_pool * mem_pool; -}; - -struct ggml_backend_metal_context { - id device; - id queue; - - dispatch_queue_t d_queue; - - struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; - - // capture state - bool capture_next_compute; - bool capture_started; - - id capture_scope; - - // command buffer state - int n_cb; // number of extra threads used to submit the command buffers - int n_nodes_0; // number of nodes submitted by the main thread - int n_nodes_1; // remaining number of nodes submitted by the n_cb threads - int n_nodes_per_cb; - - struct ggml_cgraph * gf; - - // the callback given to the thread pool - void (^encode_async)(size_t ith); - - // n_cb command buffers + 1 used by the main thread - struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; - - // abort ggml_metal_graph_compute if callback returns true - ggml_abort_callback abort_callback; - void * abort_callback_data; -}; - -// MSL code -// TODO: move the contents here when ready -// for now it is easier to work in a separate file -// static NSString * const msl_library_source = @"see metal.metal"; - -#if !GGML_METAL_EMBED_LIBRARY -// Here to assist with NSBundle Path Hack -@interface GGMLMetalClass : NSObject -@end -@implementation GGMLMetalClass -@end -#endif - -static void * ggml_metal_host_malloc(size_t n) { - void * data = NULL; - -#if TARGET_OS_OSX - kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE); - if (err != KERN_SUCCESS) { - GGML_LOG_ERROR("%s: error: vm_allocate failed\n", __func__); - return NULL; - } -#else - const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); - if (result != 0) { - GGML_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); - return NULL; - } -#endif - - return data; -} - -// load library -// -// - first check if the library is embedded -// - then check if the library is in the bundle -// - if not found, load the source and compile it -// - if that fails, return NULL -static id ggml_metal_load_library(id device, bool use_bfloat) { - id metal_library = nil; - NSError * error = nil; - NSString * src = nil; - -#if GGML_METAL_EMBED_LIBRARY - GGML_LOG_INFO("%s: using embedded metal library\n", __func__); - - extern const char ggml_metallib_start[]; - extern const char ggml_metallib_end[]; - - src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding]; - -#else - -#ifdef SWIFT_PACKAGE - NSBundle * bundle = SWIFTPM_MODULE_BUNDLE; -#else - NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; -#endif - - NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"]; - if (path_lib == nil) { - // Try to find the resource in the directory where the current binary located. - NSString * current_binary = [[NSProcessInfo processInfo] arguments][0]; - NSString * bin_dir = [current_binary stringByDeletingLastPathComponent]; - NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]]; - if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) { - GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]); - NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error]; - if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) { - // Optionally, if this is a symlink, try to resolve it. - default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error]; - if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) { - // It is a relative path, adding the binary directory as directory prefix. - default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]]; - } - if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) { - // Link to the resource could not be resolved. - default_metallib_path = nil; - } else { - GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]); - } - } - } else { - // The resource couldn't be found in the binary's directory. - default_metallib_path = nil; - } - path_lib = default_metallib_path; - } - - if (path_lib != nil) { - // pre-compiled library found - NSURL * libURL = [NSURL fileURLWithPath:path_lib]; - GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]); - - metal_library = [device newLibraryWithURL:libURL error:&error]; - if (error) { - GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - } else { - GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); - - NSString * path_source; - NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"]; - - GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil"); - - if (path_resource) { - path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"]; - } else { - path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; - } - - if (path_source == nil) { - GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); - path_source = @"ggml-metal.metal"; - } - - GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]); - - src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error]; - if (error) { - GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - } -#endif - - if (!metal_library) { - @autoreleasepool { - // dictionary of preprocessor macros - NSMutableDictionary * prep = [NSMutableDictionary dictionary]; - - if (use_bfloat) { - [prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"]; - } - -#if GGML_METAL_EMBED_LIBRARY - [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"]; -#endif - - MTLCompileOptions * options = [MTLCompileOptions new]; - options.preprocessorMacros = prep; - - //[options setFastMathEnabled:false]; - - metal_library = [device newLibraryWithSource:src options:options error:&error]; - if (error) { - GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - -#if !__has_feature(objc_arc) - [options release]; -#endif - } - } - -#if GGML_METAL_EMBED_LIBRARY - [src release]; -#endif // GGML_METAL_EMBED_LIBRARY - - return metal_library; -} - -static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) { - GGML_LOG_INFO("%s: allocating\n", __func__); - -#if TARGET_OS_OSX && !GGML_METAL_NDEBUG - // Show all the Metal device instances in the system - NSArray * devices = MTLCopyAllDevices(); - for (id device in devices) { - GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]); - } - [devices release]; // since it was created by a *Copy* C method -#endif - - // init context - struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context)); - struct ggml_backend_metal_device_context * ctx_dev = dev->context; - - id device = ctx_dev->mtl_device; - - GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); - - ctx->device = device; - ctx->queue = [device newCommandQueue]; - if (ctx->queue == nil) { - GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); - return NULL; - } - - ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); - - // load library - { - [ctx_dev->mtl_lock lock]; - - if (ctx_dev->mtl_library == nil) { - ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat); - } - - [ctx_dev->mtl_lock unlock]; - } - - id metal_library = ctx_dev->mtl_library; - if (metal_library == nil) { - GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__); - return NULL; - } - - // print MTL GPU family: - GGML_LOG_INFO("%s: GPU name: %s\n", __func__, [[device name] UTF8String]); - - // determine max supported GPU family - // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf - // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf - { - for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { - if ([device supportsFamily:i]) { - GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); - break; - } - } - - for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) { - if ([device supportsFamily:i]) { - GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i); - break; - } - } - - for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) { - if ([device supportsFamily:i]) { - GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i); - break; - } - } - } - - GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false"); - GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false"); - GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false"); - GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false"); - GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false"); - GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false"); - - ctx->capture_next_compute = false; - ctx->capture_started = false; - ctx->capture_scope = nil; - - ctx->gf = nil; - ctx->encode_async = nil; - for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { - ctx->cmd_bufs[i].obj = nil; - - ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init(); - ctx->cmd_bufs[i].mem_pool->device = device; - } - -#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) - if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6); - } -#endif - - // load kernels - { - NSError * error = nil; - - for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { - ctx->kernels[i].pipeline = nil; - } - -#define GGML_METAL_ADD_KERNEL(e, name, supported) \ - if (supported) { \ - struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ - id metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ - kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \ - GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ - (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ - (int) kernel->pipeline.threadExecutionWidth); \ - [metal_function release]; \ - if (error) { \ - GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ - return NULL; \ - } \ - } else { \ - GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \ - } - - const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm; - const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction; - const bool use_bfloat = ctx_dev->use_bfloat; - - // simd_sum and simd_max requires MTLGPUFamilyApple7 - - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2, mul_mv_ext_mxfp4_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3, mul_mv_ext_mxfp4_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4, mul_mv_ext_mxfp4_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5, mul_mv_ext_mxfp4_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, swiglu_oai, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); - } - - return ctx; -} - -static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { - GGML_LOG_INFO("%s: deallocating\n", __func__); - - for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { - [ctx->kernels[i].pipeline release]; - } - - Block_release(ctx->encode_async); - - [ctx->queue release]; - - for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { - // ctx->cmd_bufs[i].obj is auto released - - ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool); - } - - dispatch_release(ctx->d_queue); - - free(ctx); -} - -// temporarily defined here for compatibility between ggml-backend and the old API - -struct ggml_backend_metal_buffer { - void * data; - size_t size; - - id metal; -}; - -struct ggml_backend_metal_buffer_context { - void * all_data; - size_t all_size; - bool owned; - - // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap - int n_buffers; - struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; - - // optional MTLResidencySet - id rset; -}; - -// rset init -static bool ggml_backend_metal_buffer_rset_init( - struct ggml_backend_metal_buffer_context * ctx, - struct ggml_backend_metal_device_context * ctx_dev, - id device) { - ctx->rset = nil; - - if (!ctx_dev->has_residency_sets) { - return true; - } - -#if defined(GGML_METAL_HAS_RESIDENCY_SETS) - if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { - MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init]; - desc.label = @"ggml_backend_metal"; - desc.initialCapacity = ctx->n_buffers; - - NSError * error; - ctx->rset = [device newResidencySetWithDescriptor:desc error:&error]; - if (error) { - GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - [desc release]; - return false; - } - - [desc release]; - - for (int i = 0; i < ctx->n_buffers; i++) { - [ctx->rset addAllocation:ctx->buffers[i].metal]; - } - - [ctx->rset commit]; - [ctx->rset requestResidency]; - - return true; - } -#else - GGML_UNUSED(ctx_dev); - GGML_UNUSED(device); -#endif - - return true; -} - -// rset free -static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) { -#if defined(GGML_METAL_HAS_RESIDENCY_SETS) - if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { - if (ctx->rset) { - [ctx->rset endResidency]; - [ctx->rset removeAllAllocations]; - [ctx->rset release]; - } - } -#else - GGML_UNUSED(ctx); -#endif -} - -// finds the Metal buffer that contains the tensor data on the GPU device -// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the -// Metal buffer based on the host memory pointer -// -static id ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) { - //GGML_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); - - const int64_t tsize = ggml_nbytes(t); - - ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer; - - struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context; - - // find the view that contains the tensor fully - for (int i = 0; i < buf_ctx->n_buffers; ++i) { - const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data; - - //GGML_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size); - if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) { - *offs = (size_t) ioffs; - - //GGML_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs); - - return buf_ctx->buffers[i].metal; - } - } - - GGML_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name); - - return nil; -} - -static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) { - const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm; - const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction; - const bool use_bfloat = ctx_dev->use_bfloat; - - if (!use_bfloat) { - if (op->type == GGML_TYPE_BF16) { - return false; - } - - for (size_t i = 0, n = 3; i < n; ++i) { - if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) { - return false; - } - } - } - - switch (op->op) { - case GGML_OP_UNARY: - switch (ggml_get_unary_op(op)) { - case GGML_UNARY_OP_TANH: - case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_SIGMOID: - case GGML_UNARY_OP_GELU: - case GGML_UNARY_OP_GELU_ERF: - case GGML_UNARY_OP_GELU_QUICK: - case GGML_UNARY_OP_SILU: - case GGML_UNARY_OP_ELU: - case GGML_UNARY_OP_NEG: - case GGML_UNARY_OP_ABS: - case GGML_UNARY_OP_SGN: - case GGML_UNARY_OP_STEP: - case GGML_UNARY_OP_HARDSWISH: - case GGML_UNARY_OP_HARDSIGMOID: - case GGML_UNARY_OP_EXP: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; - default: - return false; - } - case GGML_OP_GLU: - switch (ggml_get_glu_op(op)) { - case GGML_GLU_OP_REGLU: - case GGML_GLU_OP_GEGLU: - case GGML_GLU_OP_SWIGLU: - case GGML_GLU_OP_SWIGLU_OAI: - case GGML_GLU_OP_GEGLU_ERF: - case GGML_GLU_OP_GEGLU_QUICK: - return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; - default: - return false; - } - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - case GGML_OP_CONCAT: - return true; - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_ADD_ID: - return op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_ACC: - case GGML_OP_REPEAT: - case GGML_OP_SCALE: - case GGML_OP_CONV_TRANSPOSE_1D: - return true; - case GGML_OP_CLAMP: - return op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SIN: - case GGML_OP_COS: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_LOG: - return false; // TODO: implement - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - case GGML_OP_SOFT_MAX: - case GGML_OP_GROUP_NORM: - return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); - case GGML_OP_RMS_NORM: - case GGML_OP_L2_NORM: - return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); - case GGML_OP_ARGMAX: - return true; - case GGML_OP_NORM: - return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); - case GGML_OP_ROPE: - return true; - case GGML_OP_IM2COL: - return op->src[0]->type == GGML_TYPE_F16; - case GGML_OP_POOL_1D: - return false; - case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; - case GGML_OP_POOL_2D: - case GGML_OP_PAD: - case GGML_OP_PAD_REFLECT_1D: - case GGML_OP_TIMESTEP_EMBEDDING: - case GGML_OP_ARGSORT: - case GGML_OP_LEAKY_RELU: - return op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_ARANGE: - return true; - case GGML_OP_FLASH_ATTN_EXT: - if (op->src[0]->ne[0] == 32) { - // head size == 32 (e.g. bert-bge-small) - // TODO: not sure if it is worth adding kernels for this size - return false; - } - if (op->src[0]->ne[0] == 576) { - // DeepSeek sizes - // TODO: disabled for now, until optmized - return false; - } - if (op->src[1]->type != op->src[2]->type) { - return false; - } - return has_simdgroup_mm; // TODO: over-restricted for vec-kernels - case GGML_OP_SSM_CONV: - case GGML_OP_SSM_SCAN: - case GGML_OP_RWKV_WKV6: - case GGML_OP_RWKV_WKV7: - return true; - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - return has_simdgroup_reduction && - (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32); - case GGML_OP_CPY: - case GGML_OP_DUP: - case GGML_OP_CONT: - { - switch (op->src[0]->type) { - case GGML_TYPE_F32: - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_BF16: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_IQ4_NL: - return true; - default: - return false; - } - case GGML_TYPE_F16: - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - return true; - default: - return false; - } - case GGML_TYPE_BF16: - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_BF16: - return true; - default: - return false; - } - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - return true; - default: - return false; - } - default: - return false; - }; - } - case GGML_OP_SET: - { - switch (op->src[0]->type) { - case GGML_TYPE_F32: - case GGML_TYPE_I32: - return true; - default: - return false; - }; - } - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_GET_ROWS: - { - return op->ne[3] == 1; - } - case GGML_OP_SET_ROWS: - { - if (op->src[0]->type != GGML_TYPE_F32) { - return false; - } - - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_BF16: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_IQ4_NL: - return true; - default: - return false; - }; - } - default: - return false; - } -} - -static int ggml_metal_encode_node( - ggml_backend_t backend, - int idx, - int idx_end, - id encoder, - struct ggml_metal_mem_pool * mem_pool) { - struct ggml_backend_metal_context * ctx = backend->context; - struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; - - struct ggml_cgraph * gf = ctx->gf; - - enum ggml_op ops[8]; - - struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx; - struct ggml_tensor * node = nodes[0]; - - //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op)); - - struct ggml_tensor * src0 = node->src[0]; - struct ggml_tensor * src1 = node->src[1]; - struct ggml_tensor * src2 = node->src[2]; - struct ggml_tensor * dst = node; - - if (ggml_is_empty(dst)) { - return 1; - } - - switch (dst->op) { - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - { - // noop -> next node - } return 1; - default: - { - } break; - } - - if (!ggml_metal_supports_op(ctx_dev, dst)) { - GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); - GGML_ABORT("unsupported op"); - } - - ggml_metal_mem_pool_clear(mem_pool); - - const int64_t ne00 = src0 ? src0->ne[0] : 0; - const int64_t ne01 = src0 ? src0->ne[1] : 0; - const int64_t ne02 = src0 ? src0->ne[2] : 0; - const int64_t ne03 = src0 ? src0->ne[3] : 0; - - const uint64_t nb00 = src0 ? src0->nb[0] : 0; - const uint64_t nb01 = src0 ? src0->nb[1] : 0; - const uint64_t nb02 = src0 ? src0->nb[2] : 0; - const uint64_t nb03 = src0 ? src0->nb[3] : 0; - - const int64_t ne10 = src1 ? src1->ne[0] : 0; - const int64_t ne11 = src1 ? src1->ne[1] : 0; - const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; - - const uint64_t nb10 = src1 ? src1->nb[0] : 0; - const uint64_t nb11 = src1 ? src1->nb[1] : 0; - const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; - - const int64_t ne20 = src2 ? src2->ne[0] : 0; - const int64_t ne21 = src2 ? src2->ne[1] : 0; - const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); - const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - - const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23); - - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; - - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; - - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; - const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; - const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; - - size_t offs_src0 = 0; - size_t offs_src1 = 0; - size_t offs_src2 = 0; - size_t offs_dst = 0; - - id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; - id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; - id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; - id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; - - int n_fuse = 1; - -#if 0 - GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); - if (src0) { - GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, - ggml_is_contiguous(src0), src0->name); - } - if (src1) { - GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, - ggml_is_contiguous(src1), src1->name); - } - if (dst) { - GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, - dst->name); - } -#endif - - id device = ctx_dev->mtl_device; - - switch (dst->op) { - case GGML_OP_CONCAT: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; - - const int32_t dim = ((const int32_t *) dst->op_params)[0]; - - ggml_metal_kargs_concat args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.dim =*/ dim, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous_rows(src0)); - GGML_ASSERT(ggml_is_contiguous_rows(src1)); - - const size_t offs = 0; - - bool bcast_row = false; - - id pipeline = nil; - - ggml_metal_kargs_bin args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.offs =*/ offs, - /*.o1 =*/ { offs_src1 }, - }; - - // c[0] = add(a, b[0]) - // c[1] = add(c[0], b[1]) - // c[2] = add(c[1], b[2]) - // ... - if (ctx_dev->use_fusion) { - ops[0] = GGML_OP_ADD; - ops[1] = GGML_OP_ADD; - ops[2] = GGML_OP_ADD; - ops[3] = GGML_OP_ADD; - ops[4] = GGML_OP_ADD; - ops[5] = GGML_OP_ADD; - ops[6] = GGML_OP_ADD; - ops[7] = GGML_OP_ADD; - - size_t offs_fuse; - id id_fuse; - - // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes - // across splits. idx_end indicates the last node in the current split - for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) { - if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) { - break; - } - - if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) { - break; - } - - // b[0] === b[1] === ... - if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) { - break; - } - - // only fuse nodes if src1 is in the same Metal buffer - id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse); - if (id_fuse != id_src1) { - break; - } - - ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++; - - args.o1[n_fuse + 1] = offs_fuse; - } - - ++n_fuse; - - if (ctx_dev->debug_fusion > 1 && n_fuse > 1) { - GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse); - } - } - - if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(src0)); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - switch (dst->op) { - case GGML_OP_ADD: - { - switch (n_fuse) { - case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break; - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break; - case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break; - case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break; - case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - bcast_row = true; - } else { - switch (dst->op) { - case GGML_OP_ADD: - { - switch (n_fuse) { - case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break; - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break; - case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break; - case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break; - case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } - - if (n_fuse > 1) { - id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:0 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - if (bcast_row) { - const int64_t n = ggml_nelements(dst)/4; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } else { - int nth = 32; - - while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { - nth *= 2; - } - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - } break; - case GGML_OP_ADD_ID: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - GGML_ASSERT(src2t == GGML_TYPE_I32); - GGML_ASSERT(dstt == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous_rows(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ID].pipeline; - - ggml_metal_kargs_add_id args = { - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb11 =*/ nb11, - /*.nb21 =*/ nb21, - - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_REPEAT: - { - id pipeline; - - switch (src0t) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; - case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; - case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - ggml_metal_kargs_repeat args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ACC: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - GGML_ASSERT(dstt == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - const size_t pnb1 = ((const int32_t *) dst->op_params)[0]; - const size_t pnb2 = ((const int32_t *) dst->op_params)[1]; - const size_t pnb3 = ((const int32_t *) dst->op_params)[2]; - const size_t offs = ((const int32_t *) dst->op_params)[3]; - - const bool inplace = (bool) ((const int32_t *) dst->op_params)[4]; - - if (!inplace) { - // run a separete kernel to cpy src->dst - // not sure how to avoid this - // TODO: make a simpler cpy_bytes kernel - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; - - ggml_metal_kargs_cpy args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; - - ggml_metal_kargs_bin args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ pnb1, - /*.nb02 =*/ pnb2, - /*.nb03 =*/ pnb3, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ pnb1, - /*.nb2 =*/ pnb2, - /*.nb3 =*/ pnb3, - /*.offs =*/ offs, - /*.o1 =*/ { offs_src1}, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:0 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_SCALE: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - float scale; - float bias; - memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float)); - - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - n /= 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - [encoder setBytes:&bias length:sizeof(bias) atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_CLAMP: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; - - float min; - float max; - memcpy(&min, ((const int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float)); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&min length:sizeof(min) atIndex:2]; - [encoder setBytes:&max length:sizeof(max) atIndex:3]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(node)) { - // we are not taking into account the strides, so for now require contiguous tensors - GGML_ASSERT(ggml_is_contiguous(src0)); - - case GGML_UNARY_OP_TANH: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_RELU: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SIGMOID: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU_ERF: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU_QUICK: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SILU: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_ELU: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ELU].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_NEG: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_ABS: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SGN: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_STEP: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_HARDSWISH: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_HARDSIGMOID: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_EXP: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - default: - { - GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_GLU: - { - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - if (src1) { - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - } - - id pipeline = nil; - - switch (ggml_get_glu_op(node)) { - case GGML_GLU_OP_REGLU: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline; - break; - case GGML_GLU_OP_GEGLU: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline; - break; - case GGML_GLU_OP_SWIGLU: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; - break; - case GGML_GLU_OP_SWIGLU_OAI: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_OAI].pipeline; - break; - case GGML_GLU_OP_GEGLU_ERF: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline; - break; - case GGML_GLU_OP_GEGLU_QUICK: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline; - break; - default: - GGML_ABORT("fatal error"); - } - - const int32_t swp = ggml_get_op_params_i32(dst, 1); - const float alpha = ggml_get_op_params_f32(dst, 2); - const float limit = ggml_get_op_params_f32(dst, 3); - - const int32_t i00 = swp ? ne0 : 0; - const int32_t i10 = swp ? 0 : ne0; - - ggml_metal_kargs_glu args = { - /*.ne00 =*/ ne00, - /*.nb01 =*/ nb01, - /*.ne10 =*/ src1 ? ne10 : ne00, - /*.nb11 =*/ src1 ? nb11 : nb01, - /*.ne0 =*/ ne0, - /*.nb1 =*/ nb1, - /*.i00 =*/ src1 ? 0 : i00, - /*.i10 =*/ src1 ? 0 : i10, - /*.alpha=*/ alpha, - /*.limit=*/ limit - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - if (src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; - - const int64_t nrows = ggml_nrows(src0); - - const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_SQR: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SQRT: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SIN: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_COS: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - { - GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - - id pipeline = nil; - - switch (dst->op) { - case GGML_OP_SUM_ROWS: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; - break; - case GGML_OP_MEAN: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline; - break; - default: - GGML_ABORT("fatal error"); - } - - int nth = 32; // SIMD width - - while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { - nth *= 2; - } - - nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); - nth = MIN(nth, ne00); - - ggml_metal_kargs_sum_rows args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_SOFT_MAX: - { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - - int nth = 32; // SIMD width - - id pipeline = nil; - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - - if (ne00%4 == 0) { - while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; - } - } else { - while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; - } - } - - float scale; - float max_bias; - - memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); - - const uint32_t n_head = src0->ne[2]; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - -// use this branch to test the ggml_metal_mem_pool functionality -#if 0 - // cpy to tmp buffer in MTLHeap - - id h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0)); - if (!h_src0) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0)); - return 0; - } - - offs_src0 = 0; - - ggml_metal_kargs_cpy args_cpy = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne00, - /*.ne1 =*/ ne01, - /*.ne2 =*/ ne02, - /*.ne3 =*/ ne03, - /*.nb0 =*/ nb00, - /*.nb1 =*/ nb01, - /*.nb2 =*/ nb02, - /*.nb3 =*/ nb03, - }; - - if (src0->type == GGML_TYPE_F16) { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline]; - } else { - [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline]; - } - [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:h_src0 offset:0 atIndex:2]; - - GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type)); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)]; - -#else - id h_src0 = id_src0; -#endif - // softmax - - ggml_metal_kargs_soft_max args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.scale =*/ scale, - /*.max_bias =*/ max_bias, - /*.m0 =*/ m0, - /*.m1 =*/ m1, - /*.n_head_log2 =*/ n_head_log2, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:h_src0 offset:offs_src0 atIndex:0]; - if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1]; - } - if (id_src2) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - } else { - [encoder setBuffer:h_src0 offset:offs_src0 atIndex:2]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&args length:sizeof(args) atIndex:4]; - - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_DIAG_MASK_INF: - { - const int n_past = ((const int32_t *)(dst->op_params))[0]; - - id pipeline = nil; - - if (ne00%8 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; - } - - ggml_metal_kargs_diag_mask_inf args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.n_past =*/ n_past, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - - if (ne00%8 == 0) { - [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - else { - [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - } break; - case GGML_OP_SSM_CONV: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; - - ggml_metal_kargs_ssm_conv args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SSM_SCAN: - { - struct ggml_tensor * src3 = node->src[3]; - struct ggml_tensor * src4 = node->src[4]; - struct ggml_tensor * src5 = node->src[5]; - struct ggml_tensor * src6 = node->src[6]; - - GGML_ASSERT(src3); - GGML_ASSERT(src4); - GGML_ASSERT(src5); - GGML_ASSERT(src6); - - size_t offs_src3 = 0; - size_t offs_src4 = 0; - size_t offs_src5 = 0; - size_t offs_src6 = 0; - - id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; - id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; - id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; - id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; - - const int64_t ne30 = src3->ne[0]; - const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); - - const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30); - const uint64_t nb31 = src3->nb[1]; - - const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); - const int64_t ne41 = src4->ne[1]; - const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); - const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); - - const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40); - const uint64_t nb41 = src4->nb[1]; - const uint64_t nb42 = src4->nb[2]; - const uint64_t nb43 = src4->nb[3]; - - const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); - const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); - const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); - const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); - - const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50); - const uint64_t nb51 = src5->nb[1]; - const uint64_t nb52 = src5->nb[2]; - const uint64_t nb53 = src5->nb[3]; - - const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); - - const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60); - - const int64_t d_state = ne00; - const int64_t d_inner = ne01; - const int64_t n_head = ne02; - const int64_t n_group = ne41; - const int64_t n_seq_tokens = ne12; - const int64_t n_seqs = ne13; - - id pipeline = nil; - - if (ne30 == 1) { - // Mamba-2 - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; - } - - ggml_metal_kargs_ssm_scan args = { - /*.d_state =*/ d_state, - /*.d_inner =*/ d_inner, - /*.n_head =*/ n_head, - /*.n_group =*/ n_group, - /*.n_seq_tokens =*/ n_seq_tokens, - /*.n_seqs =*/ n_seqs, - /*.s_off =*/ ggml_nelements(src1) * sizeof(float), - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb21 =*/ nb21, - /*.nb22 =*/ nb22, - /*.nb31 =*/ nb31, - /*.nb41 =*/ nb41, - /*.nb42 =*/ nb42, - /*.nb43 =*/ nb43, - /*.nb51 =*/ nb51, - /*.nb52 =*/ nb52, - /*.nb53 =*/ nb53, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; - [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; - [encoder setBytes:&args length:sizeof(args) atIndex:8]; - - // One shared memory bucket for each simd group in the threadgroup - // NOTE: Metal kernels require the buffer size to be multiple of 16 bytes - // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength - if (d_state >= 32) { - GGML_ASSERT((int64_t)(d_state / 32) <= 32); - const int64_t shmem_size = 32; - GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup); - [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0]; - } - - if (ne30 == 1) { - // Mamba-2 - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; - } else { - GGML_ASSERT(d_inner == 1); - [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; - } - } break; - case GGML_OP_RWKV_WKV6: - { - const int64_t B = dst->src[5]->ne[1]; - const int64_t T = dst->src[0]->ne[2]; - const int64_t C = dst->ne[0]; - const int64_t H = dst->src[0]->ne[1]; - - GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); - GGML_ASSERT(C % H == 0); - GGML_ASSERT(C / H == 64); - - size_t offs_src3 = 0; - size_t offs_src4 = 0; - size_t offs_src5 = 0; - - id id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil; - id id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil; - id id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; - [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - - [encoder setBytes:&B length:sizeof(B) atIndex:7]; - [encoder setBytes:&T length:sizeof(T) atIndex:8]; - [encoder setBytes:&C length:sizeof(C) atIndex:9]; - [encoder setBytes:&H length:sizeof(H) atIndex:10]; - - [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)]; - } break; - case GGML_OP_RWKV_WKV7: - { - const int64_t B = dst->src[6]->ne[1]; - const int64_t T = dst->src[0]->ne[2]; - const int64_t C = dst->ne[0]; - const int64_t H = dst->src[0]->ne[1]; - - GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32); - GGML_ASSERT(C % H == 0); - GGML_ASSERT(C / H == 64); - - size_t offs_src3 = 0; - size_t offs_src4 = 0; - size_t offs_src5 = 0; - size_t offs_src6 = 0; - - id id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil; - id id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil; - id id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil; - id id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; - [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; - - [encoder setBytes:&B length:sizeof(B) atIndex:8]; - [encoder setBytes:&T length:sizeof(T) atIndex:9]; - [encoder setBytes:&C length:sizeof(C) atIndex:10]; - [encoder setBytes:&H length:sizeof(H) atIndex:11]; - - [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)]; - } break; - case GGML_OP_MUL_MAT: - { - GGML_ASSERT(ne00 == ne10); - - GGML_ASSERT(ne12 % ne02 == 0); - GGML_ASSERT(ne13 % ne03 == 0); - - const uint32_t r2 = ne12/ne02; - const uint32_t r3 = ne13/ne03; - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - const int ne11_mm_min = 4; - - // first try to use small-batch mat-mv kernels - // these should be efficient for BS [2, ~8] - if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) && - ( - ( - ( - src0t == GGML_TYPE_F16 || // TODO: helper function - src0t == GGML_TYPE_Q4_0 || - src0t == GGML_TYPE_Q4_1 || - src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || - src0t == GGML_TYPE_Q8_0 || - src0t == GGML_TYPE_MXFP4 || - src0t == GGML_TYPE_IQ4_NL || - false) && (ne11 >= 2 && ne11 <= 8) - ) || - ( - ( - src0t == GGML_TYPE_Q4_K || - src0t == GGML_TYPE_Q5_K || - src0t == GGML_TYPE_Q6_K || - false) && (ne11 >= 4 && ne11 <= 8) - ) - ) - ) { - // TODO: determine the optimal parameters based on grid utilization - // I still don't know why we should not always use the maximum available threads: - // - // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32 - // - // my current hypothesis is that the work grid is not evenly divisible for different nsg - // values and there can be some tail effects when nsg is high. need to confirm this - // - const int nsg = 2; // num simdgroups per threadgroup - const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup - const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time) - const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup - int r1ptg = 4; // num src1 rows per threadgroup - - // note: not sure how optimal are those across all different hardware. there might be someting cleverer - switch (ne11) { - case 2: - r1ptg = 2; break; - case 3: - case 6: - r1ptg = 3; break; - case 4: - case 7: - case 8: - r1ptg = 4; break; - case 5: - r1ptg = 5; break; - }; - - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F16: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q4_0: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q4_1: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q5_0: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q5_1: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q8_0: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_MXFP4: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q4_K: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q5_K: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_Q6_K: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - case GGML_TYPE_IQ4_NL: - switch (r1ptg) { - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break; - case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break; - case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break; - default: GGML_ABORT("not implemented"); - } break; - default: GGML_ABORT("not implemented"); - } - - ggml_metal_kargs_mul_mv_ext args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.r2 =*/ r2, - /*.r3 =*/ r3, - /*.nsg =*/ nsg, - /*.nxpsg =*/ nxpsg, - /*.r1ptg =*/ r1ptg, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg); - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } else - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs - // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if ([device supportsFamily:MTLGPUFamilyApple7] && - !ggml_is_transposed(src0) && - !ggml_is_transposed(src1) && - src1t == GGML_TYPE_F32 && - ne00 % 32 == 0 && ne00 >= 64 && - (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { - //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - - // some Metal matrix data types require aligned pointers - // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) - switch (src0->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; - case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; - default: break; - } - - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; - case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; - default: GGML_ABORT("MUL MAT-MAT not implemented"); - } - - ggml_metal_kargs_mul_mm args = { - /*.ne00 =*/ ne00, - /*.ne02 =*/ ne02, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne12 =*/ ne12, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.r2 =*/ r2, - /*.r3 =*/ r3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } else { - id pipeline = nil; - - int nsg = 0; // number of simdgroups - int nr0 = 0; // number of src0 rows per simdgroup - int nr1 = 1; // number of src1 rows per threadgroup - - size_t smem = 0; // shared memory - - // use custom matrix x vector kernel - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nsg = 1; - nr0 = 1; - nr1 = 4; - if (ne00 == 4) { - nr0 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; - } - } break; - case GGML_TYPE_F16: - { - nsg = 1; - nr0 = 1; - if (src1t == GGML_TYPE_F32) { - if (ne00 == 4) { - nr0 = 32; - nr1 = 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline; - } else if (ne11 * ne12 < 4) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; - } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; - nr1 = ne11; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; - nr1 = 4; - } - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; - nr1 = 4; - } - } break; - case GGML_TYPE_BF16: - { - nsg = 1; - nr0 = 1; - if (src1t == GGML_TYPE_F32) { - if (ne00 == 4) { - nr0 = 32; - nr1 = 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline; - } else if (ne11 * ne12 < 4) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; - } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; - nr1 = ne11; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; - nr1 = 4; - } - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; - nr1 = 4; - } - } break; - case GGML_TYPE_Q4_0: - { - nsg = N_SG_Q4_0; - nr0 = N_R0_Q4_0; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; - } break; - case GGML_TYPE_Q4_1: - { - nsg = N_SG_Q4_1; - nr0 = N_R0_Q4_1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; - } break; - case GGML_TYPE_Q5_0: - { - nsg = N_SG_Q5_0; - nr0 = N_R0_Q5_0; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; - } break; - case GGML_TYPE_Q5_1: - { - nsg = N_SG_Q5_1; - nr0 = N_R0_Q5_1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; - } break; - case GGML_TYPE_Q8_0: - { - nsg = N_SG_Q8_0; - nr0 = N_R0_Q8_0; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; - } break; - case GGML_TYPE_MXFP4: - { - nsg = N_SG_MXFP4; - nr0 = N_R0_MXFP4; - smem = 32*sizeof(float); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline; - } break; - case GGML_TYPE_Q2_K: - { - nsg = N_SG_Q2_K; - nr0 = N_R0_Q2_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; - } break; - case GGML_TYPE_Q3_K: - { - nsg = N_SG_Q3_K; - nr0 = N_R0_Q3_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; - } break; - case GGML_TYPE_Q4_K: - { - nsg = N_SG_Q4_K; - nr0 = N_R0_Q4_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; - } break; - case GGML_TYPE_Q5_K: - { - nsg = N_SG_Q5_K; - nr0 = N_R0_Q5_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; - } break; - case GGML_TYPE_Q6_K: - { - nsg = N_SG_Q6_K; - nr0 = N_R0_Q6_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XXS: - { - nsg = N_SG_IQ2_XXS; - nr0 = N_R0_IQ2_XXS; - smem = 256*8+128; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XS: - { - nsg = N_SG_IQ2_XS; - nr0 = N_R0_IQ2_XS; - smem = 512*8+128; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_XXS: - { - nsg = N_SG_IQ3_XXS; - nr0 = N_R0_IQ3_XXS; - smem = 256*4+128; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_S: - { - nsg = N_SG_IQ3_S; - nr0 = N_R0_IQ3_S; - smem = 512*4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; - } break; - case GGML_TYPE_IQ2_S: - { - nsg = N_SG_IQ2_S; - nr0 = N_R0_IQ2_S; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_S: - { - nsg = N_SG_IQ1_S; - nr0 = N_R0_IQ1_S; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_M: - { - nsg = N_SG_IQ1_M; - nr0 = N_R0_IQ1_M; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; - } break; - case GGML_TYPE_IQ4_NL: - { - nsg = N_SG_IQ4_NL; - nr0 = N_R0_IQ4_NL; - smem = 32*sizeof(float); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; - } break; - case GGML_TYPE_IQ4_XS: - { - nsg = N_SG_IQ4_XS; - nr0 = N_R0_IQ4_XS; - smem = 32*sizeof(float); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; - } break; - default: - { - GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t); - GGML_ABORT("not implemented"); - } - }; - - ggml_metal_kargs_mul_mv args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.r2 =*/ r2, - /*.r3 =*/ r3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - if (smem > 0) { - [encoder setThreadgroupMemoryLength:smem atIndex:0]; - } - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } - } break; - case GGML_OP_MUL_MAT_ID: - { - // src2 = ids - GGML_ASSERT(src2t == GGML_TYPE_I32); - - GGML_ASSERT(!ggml_is_transposed(src0)); - GGML_ASSERT(!ggml_is_transposed(src1)); - - GGML_ASSERT(src1t == GGML_TYPE_F32); - - GGML_ASSERT(ne03 == 1); - GGML_ASSERT(ne13 == 1); - - const uint32_t r2 = 1; - const uint32_t r3 = 1; - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - // ne20 = n_used_experts - // ne21 = n_rows (batch size) - const int ne21_mm_id_min = 32; - - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs - // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if ([device supportsFamily:MTLGPUFamilyApple7] && - ne00 % 32 == 0 && ne00 >= 64 && - (ne21 >= ne21_mm_id_min)) { - GGML_ASSERT(ne00 % 4 == 0); - - // some Metal matrix data types require aligned pointers - // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) - switch (src0->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; - case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; - default: break; - } - - const int64_t neh10 = ne10; // n_embd - const int64_t neh11 = ne21; // n_tokens - const int64_t neh12 = ne02; // n_expert - - const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16); - const uint64_t nbh11 = nbh10*neh10; - const uint64_t nbh12 = nbh11*neh11; - const uint64_t nbh13 = nbh12*neh12; - - const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12; - id h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1); - if (!h_src1) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1); - return 0; - } - - const int64_t neh0 = ne0; - const int64_t neh1 = ne21; - const int64_t neh2 = ne02; - - const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32); - const uint64_t nbh1 = nbh0*neh0; - const uint64_t nbh2 = nbh1*neh1; - //const uint64_t nbh3 = nbh2*neh2; - - const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2; - id h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst); - if (!h_dst) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst); - return 0; - } - - // tokens per expert - const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02; - id h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe); - if (!h_tpe) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe); - return 0; - } - - // id map - // [n_expert_used, n_tokens] - const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21; - id h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids); - if (!h_ids) { - GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids); - return 0; - } - - { - const int nth = MIN(1024, ne10/4); - - ggml_metal_kargs_mul_mm_id_map0 args = { - ne10, - ne11, // n_expert_used (bcast) - nb11, - nb12, - neh11, // n_tokens - nbh11, - ne20, // n_expert_used - nb21, - }; - - id pipeline = nil; - - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer: h_src1 offset:0 atIndex:3]; - [encoder setBuffer: h_tpe offset:0 atIndex:4]; - [encoder setBuffer: h_ids offset:0 atIndex:5]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - - { - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break; - case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break; - default: GGML_ABORT("MUL_MAT_ID not implemented"); - } - - ggml_metal_kargs_mul_mm_id args = { - /*.ne00 =*/ ne00, - /*.ne02 =*/ ne02, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.neh12 =*/ neh12, - /*.nbh10 =*/ nbh10, - /*.nbh11 =*/ nbh11, - /*.nbh12 =*/ nbh12, - /*.nbh13 =*/ nbh13, - /*.neh0 =*/ neh0, - /*.neh1 =*/ neh1, - /*.r2 =*/ r2, - /*.r3 =*/ r3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer: h_src1 offset:0 atIndex:2]; - [encoder setBuffer: h_tpe offset:0 atIndex:3]; - [encoder setBuffer: h_dst offset:0 atIndex:4]; - - [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } - - { - GGML_ASSERT(ne0 % 4 == 0); - - const int nth = MIN(1024, ne0/4); - - ggml_metal_kargs_mul_mm_id_map1 args = { - ne20, // n_expert_used - neh0, - neh1, - nbh1, - nbh2, - ne0, - nb1, - nb2, - }; - - id pipeline = nil; - - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer: h_dst offset:0 atIndex:1]; - [encoder setBuffer: h_ids offset:0 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - } else { - id pipeline = nil; - - int nsg = 0; // number of simdgroups - int nr0 = 0; // number of src0 rows per simdgroup - int nr1 = 1; // number of src1 rows per threadgroup - - size_t smem = 0; // shared memory - - // use custom matrix x vector kernel - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nsg = 1; - nr0 = 1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nsg = 1; - nr0 = 1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; - } break; - case GGML_TYPE_BF16: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nsg = 1; - nr0 = 1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; - } break; - case GGML_TYPE_Q4_0: - { - nsg = N_SG_Q4_0; - nr0 = N_R0_Q4_0; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; - } break; - case GGML_TYPE_Q4_1: - { - nsg = N_SG_Q4_1; - nr0 = N_R0_Q4_1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; - } break; - case GGML_TYPE_Q5_0: - { - nsg = N_SG_Q5_0; - nr0 = N_R0_Q5_0; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; - } break; - case GGML_TYPE_Q5_1: - { - nsg = N_SG_Q5_1; - nr0 = N_R0_Q5_1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; - } break; - case GGML_TYPE_Q8_0: - { - nsg = N_SG_Q8_0; - nr0 = N_R0_Q8_0; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; - } break; - case GGML_TYPE_MXFP4: - { - nsg = N_SG_MXFP4; - nr0 = N_R0_MXFP4; - smem = 32*sizeof(float); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline; - } break; - case GGML_TYPE_Q2_K: - { - nsg = N_SG_Q2_K; - nr0 = N_R0_Q2_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; - } break; - case GGML_TYPE_Q3_K: - { - nsg = N_SG_Q3_K; - nr0 = N_R0_Q3_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; - } break; - case GGML_TYPE_Q4_K: - { - nsg = N_SG_Q4_K; - nr0 = N_R0_Q4_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; - } break; - case GGML_TYPE_Q5_K: - { - nsg = N_SG_Q5_K; - nr0 = N_R0_Q5_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; - } break; - case GGML_TYPE_Q6_K: - { - nsg = N_SG_Q6_K; - nr0 = N_R0_Q6_K; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XXS: - { - nsg = N_SG_IQ2_XXS; - nr0 = N_R0_IQ2_XXS; - smem = 256*8+128; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XS: - { - nsg = N_SG_IQ2_XS; - nr0 = N_R0_IQ2_XS; - smem = 512*8+128; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_XXS: - { - nsg = N_SG_IQ3_XXS; - nr0 = N_R0_IQ3_XXS; - smem = 256*4+128; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_S: - { - nsg = N_SG_IQ3_S; - nr0 = N_R0_IQ3_S; - smem = 512*4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; - } break; - case GGML_TYPE_IQ2_S: - { - nsg = N_SG_IQ2_S; - nr0 = N_R0_IQ2_S; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_S: - { - nsg = N_SG_IQ1_S; - nr0 = N_R0_IQ1_S; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_M: - { - nsg = N_SG_IQ1_M; - nr0 = N_R0_IQ1_M; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; - } break; - case GGML_TYPE_IQ4_NL: - { - nsg = N_SG_IQ4_NL; - nr0 = N_R0_IQ4_NL; - smem = 32*sizeof(float); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; - } break; - case GGML_TYPE_IQ4_XS: - { - nsg = N_SG_IQ4_XS; - nr0 = N_R0_IQ4_XS; - smem = 32*sizeof(float); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; - } break; - default: - { - GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t); - GGML_ABORT("not implemented"); - } - }; - - if (ggml_is_quantized(src0t)) { - GGML_ASSERT(ne00 >= nsg*nr0); - } - - ggml_metal_kargs_mul_mv_id args = { - /*.nei0 =*/ ne20, - /*.nei1 =*/ ne21, - /*.nbi1 =*/ nb21, - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.nb1 =*/ nb1, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; - - const int64_t _ne1 = 1; - const int64_t ne123 = ne20*ne21; - - if (smem > 0) { - [encoder setThreadgroupMemoryLength:smem atIndex:0]; - } - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } - } break; - case GGML_OP_GET_ROWS: - { - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; - case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; - case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; - default: GGML_ABORT("not implemented"); - } - - ggml_metal_kargs_get_rows args = { - /*.ne00 =*/ ne00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.ne10 =*/ ne10, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; - } break; - case GGML_OP_SET_ROWS: - { - id pipeline = nil; - - switch (dst->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break; - default: GGML_ABORT("not implemented"); - } - - const int32_t nk0 = ne0/ggml_blck_size(dst->type); - - int nth = 32; // SIMD width - - while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { - nth *= 2; - } - - int nrptg = 1; - if (nth > nk0) { - nrptg = (nth + nk0 - 1)/nk0; - nth = nk0; - - if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) { - nrptg--; - } - } - - nth = MIN(nth, nk0); - - ggml_metal_kargs_set_rows args = { - /*.nk0 =*/ nk0, - /*.ne01 =*/ ne01, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)]; - } break; - case GGML_OP_RMS_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_rows(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - ggml_metal_kargs_rms_norm args = { - /*.ne00 =*/ ne00, - /*.ne00_4 =*/ ne00/4, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.eps =*/ eps, - /*.nef1 =*/ { ne01 }, - /*.nef2 =*/ { ne02 }, - /*.nef3 =*/ { ne03 }, - /*.nbf1 =*/ { nb01 }, - /*.nbf2 =*/ { nb02 }, - /*.nbf3 =*/ { nb03 }, - }; - - size_t offs_fuse[2] = { 0, 0 }; - id id_fuse[2] = { id_src0, id_src0 }; - - // d[0] = rms_norm(a) - // d[1] = mul(d[0], b) - // d[2] = add(d[1], c) - if (ctx_dev->use_fusion) { - ops[0] = GGML_OP_RMS_NORM; - ops[1] = GGML_OP_MUL; - ops[2] = GGML_OP_ADD; - - for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) { - if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) { - break; - } - - if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) { - break; - } - - if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) { - break; - } - - if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) { - break; - } - - if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) { - break; - } - - ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++; - - id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]); - - args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1]; - args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2]; - args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3]; - - args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1]; - args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2]; - args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3]; - } - - ++n_fuse; - - if (ctx_dev->debug_fusion > 1 && n_fuse > 1) { - if (n_fuse == 2) { - GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__); - } - if (n_fuse == 3) { - GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__); - } - } - } - - if (n_fuse > 1) { - id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst); - } - - id pipeline; - - switch (n_fuse) { - case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break; - case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break; - case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break; - default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse); - } - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { - nth *= 2; - } - - nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); - nth = MIN(nth, ne00/4); - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2]; - [encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_L2_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline; - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { - nth *= 2; - } - - nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); - nth = MIN(nth, ne00/4); - - ggml_metal_kargs_l2_norm args = { - /*.ne00 =*/ ne00, - /*.ne00_4 =*/ ne00/4, - /*.nb01 =*/ nb01, - /*.eps =*/ eps, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_GROUP_NORM: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - float eps; - memcpy(&eps, dst->op_params + 1, sizeof(float)); - - const int32_t n_groups = ((const int32_t *) dst->op_params)[0]; - - int nth = 32; // SIMD width - - //while (nth < ne00/4 && nth < 1024) { - // nth *= 2; - //} - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; - - ggml_metal_kargs_group_norm args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.n_groups =*/ n_groups, - /*.eps =*/ eps, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { - nth *= 2; - } - - nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); - nth = MIN(nth, ne00/4); - - ggml_metal_kargs_norm args = { - /*.ne00 =*/ ne00, - /*.ne00_4 =*/ ne00/4, - /*.nb01 =*/ nb01, - /*.eps =*/ eps, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ROPE: - { - - // make sure we have one or more position id(ne10) per token(ne02) - GGML_ASSERT(ne10 % ne02 == 0); - GGML_ASSERT(ne10 >= ne02); - - const int nth = MIN(1024, ne00); - - const int n_past = ((const int32_t *) dst->op_params)[0]; - const int n_dims = ((const int32_t *) dst->op_params)[1]; - const int mode = ((const int32_t *) dst->op_params)[2]; - // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal - const int n_ctx_orig = ((const int32_t *) dst->op_params)[4]; - - float freq_base; - float freq_scale; - float ext_factor; - float attn_factor; - float beta_fast; - float beta_slow; - - memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (const int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float)); - - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; - const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; - const bool is_vision = mode == GGML_ROPE_TYPE_VISION; - - // mrope - const int sect_0 = ((const int32_t *) dst->op_params)[11]; - const int sect_1 = ((const int32_t *) dst->op_params)[12]; - const int sect_2 = ((const int32_t *) dst->op_params)[13]; - const int sect_3 = ((const int32_t *) dst->op_params)[14]; - - id pipeline = nil; - - if (is_neox) { - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } else if (is_mrope && !is_vision) { - GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } else if (is_vision) { - GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } else { - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } - - ggml_metal_kargs_rope args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.n_past =*/ n_past, - /*.n_dims =*/ n_dims, - /*.n_ctx_orig =*/ n_ctx_orig, - /*.freq_base =*/ freq_base, - /*.freq_scale =*/ freq_scale, - /*.ext_factor =*/ ext_factor, - /*.attn_factor =*/ attn_factor, - /*.beta_fast =*/ beta_fast, - /*.beta_slow =*/ beta_slow, - /* sect_0 =*/ sect_0, - /* sect_1 =*/ sect_1, - /* sect_2 =*/ sect_2, - /* sect_3 =*/ sect_3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_IM2COL: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; - - const int32_t N = src1->ne[is_2D ? 3 : 2]; - const int32_t IC = src1->ne[is_2D ? 2 : 1]; - const int32_t IH = is_2D ? src1->ne[1] : 1; - const int32_t IW = src1->ne[0]; - - const int32_t KH = is_2D ? src0->ne[1] : 1; - const int32_t KW = src0->ne[0]; - - const int32_t OH = is_2D ? dst->ne[2] : 1; - const int32_t OW = dst->ne[1]; - - const int32_t CHW = IC * KH * KW; - - const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; - const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; - - const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup; - - switch (dst->type) { - case GGML_TYPE_F32: { - pipeline = (is_gt_mttpt ? - ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline - : - ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline); - } break; - case GGML_TYPE_F16: { - pipeline = (is_gt_mttpt ? - ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline - : - ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline); - } break; - default: GGML_ABORT("fatal error"); - }; - - ggml_metal_kargs_im2col args = { - /*.ofs0 =*/ ofs0, - /*.ofs1 =*/ ofs1, - /*.IW =*/ IW, - /*.IH =*/ IH, - /*.CHW =*/ CHW, - /*.s0 =*/ s0, - /*.s1 =*/ s1, - /*.p0 =*/ p0, - /*.p1 =*/ p1, - /*.d0 =*/ d0, - /*.d1 =*/ d1, - /*.N =*/ N, - /*.KH =*/ KH, - /*.KW =*/ KW, - /*.KHW =*/ KH * KW, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - - if (is_gt_mttpt) { - const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N); - - const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); - - [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; - } else { - [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; - } - } break; - case GGML_OP_CONV_TRANSPOSE_1D: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - - const int32_t IC = src1->ne[1]; - const int32_t IL = src1->ne[0]; - - const int32_t K = src0->ne[0]; - - const int32_t OL = dst->ne[0]; - const int32_t OC = dst->ne[1]; - - id pipeline; - - switch (src0->type) { - case GGML_TYPE_F32: { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline; - } break; - case GGML_TYPE_F16: { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline; - } break; - default: GGML_ABORT("fatal error"); - }; - - ggml_metal_kargs_conv_transpose_1d args = { - /*.IC =*/ IC, - /*.IL =*/ IL, - /*.K =*/ K, - /*.s0 =*/ s0, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_UPSCALE: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const float sf0 = (float)ne0/src0->ne[0]; - const float sf1 = (float)ne1/src0->ne[1]; - const float sf2 = (float)ne2/src0->ne[2]; - const float sf3 = (float)ne3/src0->ne[3]; - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; - - ggml_metal_kargs_upscale args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.sf0 =*/ sf0, - /*.sf1 =*/ sf1, - /*.sf2 =*/ sf2, - /*.sf3 =*/ sf3 - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_PAD: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; - - ggml_metal_kargs_pad args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3 - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_PAD_REFLECT_1D: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const int32_t p0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[1]; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline; - - ggml_metal_kargs_pad_reflect_1d args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.p0 =*/ p0, - /*.p1 =*/ p1 - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ARANGE: - { - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - float start; - float step; - - memcpy(&start, ((const int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&step, ((const int32_t *) dst->op_params) + 2, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; - - ggml_metal_kargs_arange args = { - /*.ne0 =*/ ne0, - /*.start =*/ start, - /*.step =*/ step - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; - [encoder setBytes:&args length:sizeof(args) atIndex:1]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_TIMESTEP_EMBEDDING: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const int dim = dst->op_params[0]; - const int max_period = dst->op_params[1]; - - const int half = dim / 2; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; - - ggml_metal_kargs_timestep_embedding args = { - /*.nb1 =*/ nb1, - /*.dim =*/ dim, - /*.max_period =*/ max_period - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - - const int nth = MIN(1024, half); - - [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ARGSORT: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_I32); - - const int nrows = ggml_nrows(src0); - - enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - - // bitonic sort requires the number of elements to be power of 2 - int64_t ne00_padded = 1; - while (ne00_padded < ne00) { - ne00_padded *= 2; - } - - // Metal kernels require the buffer size to be multiple of 16 bytes - // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength - const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); - - id pipeline = nil; - - switch (order) { - case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; - case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - - ggml_metal_kargs_argsort args = { - /*.ncols =*/ ne00, - /*.ncols_pad =*/ ne00_padded - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; - } break; - case GGML_OP_LEAKY_RELU: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - float slope; - memcpy(&slope, dst->op_params, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; - - ggml_metal_kargs_leaky_relu args = { - /*.slope =*/ slope - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_FLASH_ATTN_EXT: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ne11 % 32 == 0); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == src2->type); - - //GGML_ASSERT(ggml_are_same_shape (src1, src2)); - GGML_ASSERT(ne11 == ne21); - GGML_ASSERT(ne12 == ne22); - - struct ggml_tensor * src3 = node->src[3]; // mask - struct ggml_tensor * src4 = node->src[4]; // sinks - - size_t offs_src3 = 0; - size_t offs_src4 = 0; - - id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; - id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; - - GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); - GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && - "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); - - const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); - //const int64_t ne31 = src3 ? src3->ne[1] : 0; - const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); - const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); - - const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); - const uint64_t nb31 = src3 ? src3->nb[1] : 0; - const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); - const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); - - float scale; - float max_bias; - float logit_softcap; - memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); - memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); - - if (logit_softcap != 0.0f) { - scale /= logit_softcap; - } - - const uint32_t n_head = src0->ne[2]; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - id pipeline = nil; - - bool use_vec_kernel = false; - - // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) - // for now avoiding mainly to keep the number of templates/kernels a bit lower - // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612 - if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) { - switch (src1->type) { - case GGML_TYPE_F16: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_BF16: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q4_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q4_1: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q5_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q5_1: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q8_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - use_vec_kernel = true; - - switch (ne00) { - case 64: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 96: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 128: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 192: - { - if (ne20 == 128) { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } - } break; - case 256: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 576: - { - if (ne20 == 512) { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - GGML_LOG_ERROR("unsupported size: %lld\n", ne20); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - - ggml_metal_kargs_flash_attn_ext args = { - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne11 =*/ ne11, - /*.ne_12_2 =*/ ne12, - /*.ne_12_3 =*/ ne13, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb21 =*/ nb21, - /*.nb22 =*/ nb22, - /*.nb23 =*/ nb23, - /*.ne32 =*/ ne32, - /*.ne33 =*/ ne33, - /*.nb31 =*/ nb31, - /*.nb32 =*/ nb32, - /*.nb33 =*/ nb33, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.scale =*/ scale, - /*.max_bias =*/ max_bias, - /*.m0 =*/ m0, - /*.m1 =*/ m1, - /*.n_head_log2 =*/ n_head_log2, - /*.logit_softcap =*/ logit_softcap, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; - } - if (id_src4) { - [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - - if (!use_vec_kernel) { - // half8x8 kernel - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); - - const int is_q = ggml_is_quantized(src1->type) ? 1 : 0; - - // 2*(2*ncpsg + nqptg)*(nsg) - // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float) - // - // 16*32*(nsg) - // the shared memory needed for the simdgroups to load the KV cache - // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG - // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) - - int64_t nsgmax = 2; - - while (true) { - const size_t smem = FATTN_SMEM(nsgmax); - if (smem > device.maxThreadgroupMemoryLength) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; - - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - - const size_t smem = FATTN_SMEM(nsg); - - //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); - GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:smem atIndex:0]; -#undef FATTN_SMEM - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } else { - // half4x4 kernel - const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 1 == 0); - GGML_ASSERT(ncpsg % 32 == 0); - - // ne00 + 2*ncpsg*(nsg) - // for each query, we load it as f16 in shared memory (ne00) - // and store the soft_max values and the mask - // - // ne00*(nsg) - // each simdgroup has a full f32 head vector in shared mem to accumulate results - // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16)) - - int64_t nsgmax = 2; - while (true) { - const size_t smem = FATTN_SMEM(nsgmax); - if (smem > device.maxThreadgroupMemoryLength) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; - - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); - - int64_t nsg = 1; - while (nsg <= nsgt) { - nsg *= 2; - } - nsg /= 2; - - const size_t smem = FATTN_SMEM(nsg); - - //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); - GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:smem atIndex:0]; -#undef FATTN_SMEM - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } - } break; - case GGML_OP_DUP: - case GGML_OP_CPY: - case GGML_OP_CONT: - { - id pipeline = nil; - - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); - - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_F16: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_BF16: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_Q4_0: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_Q4_1: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_Q5_0: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_Q5_1: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_Q8_0: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - default: GGML_ABORT("not implemented"); - } - - GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - - // TODO: support - //const int32_t nk00 = ne00/ggml_blck_size(dst->type); - const int32_t nk00 = ne00; - - int nth = 32; // SIMD width - - while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { - nth *= 2; - } - - nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); - - // when rows are small, we can batch them together in a single threadgroup - int nrptg = 1; - - // TODO: relax this constraint in the future - if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) { - if (nth > nk00) { - nrptg = (nth + nk00 - 1)/nk00; - nth = nk00; - - if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) { - nrptg--; - } - } - } - - nth = MIN(nth, nk00); - - ggml_metal_kargs_cpy args = { - /*.ne00 =*/ nk00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)]; - } break; - case GGML_OP_SET: - { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - - // src0 and dst as viewed during set - const size_t dst_nb0 = ggml_element_size(src0); - - const size_t dst_nb1 = ((int32_t *) dst->op_params)[0]; - const size_t dst_nb2 = ((int32_t *) dst->op_params)[1]; - const size_t dst_nb3 = ((int32_t *) dst->op_params)[2]; - const size_t offset = ((int32_t *) dst->op_params)[3]; - const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace) { - memcpy(((char *) dst->data), ((char *) src0->data), ggml_nbytes(dst)); - } - - const int im0 = (ne10 == 0 ? 0 : ne10-1); - const int im1 = (ne11 == 0 ? 0 : ne11-1); - const int im2 = (ne12 == 0 ? 0 : ne12-1); - const int im3 = (ne13 == 0 ? 0 : ne13-1); - - GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= ggml_nbytes(dst)); - - id pipeline = nil; - - switch (src0t) { - case GGML_TYPE_F32: - GGML_ASSERT(nb10 == sizeof(float)); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break; - case GGML_TYPE_I32: - GGML_ASSERT(nb10 == sizeof(int32_t)); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - ggml_metal_kargs_set args = { - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb1 =*/ dst_nb1, - /*.nb2 =*/ dst_nb2, - /*.nb3 =*/ dst_nb3, - /*.offs =*/ offset, - /*.inplace =*/ inplace, - }; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10); - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_POOL_2D: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt); - - const int32_t * opts = dst->op_params; - enum ggml_op_pool op = opts[0]; - - id pipeline = nil; - switch (src0t) { - case GGML_TYPE_F32: { - switch(op) { - case GGML_OP_POOL_AVG: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break; - case GGML_OP_POOL_MAX: - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break; - default: GGML_ASSERT(false && "not implemented"); - } - } break; - default: GGML_ASSERT(false && "not implemented"); - } - - const int32_t k0 = opts[1]; - const int32_t k1 = opts[2]; - const int32_t s0 = opts[3]; - const int32_t s1 = opts[4]; - const int32_t p0 = opts[5]; - const int32_t p1 = opts[6]; - - const int64_t IH = src0->ne[1]; - const int64_t IW = src0->ne[0]; - - const int64_t N = dst->ne[3]; - const int64_t OC = dst->ne[2]; - const int64_t OH = dst->ne[1]; - const int64_t OW = dst->ne[0]; - - const int64_t parallel_elements = N * OC * OH * OW; - const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); - const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; - - ggml_metal_kargs_pool_2d args_pool_2d = { - /* .k0 = */ k0, - /* .k1 = */ k1, - /* .s0 = */ s0, - /* .s1 = */ s1, - /* .p0 = */ p0, - /* .p1 = */ p1, - /* .IH = */ IH, - /* .IW = */ IW, - /* .OH = */ OH, - /* .OW = */ OW, - /* .parallel_elements = */ parallel_elements - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; - } break; - case GGML_OP_ARGMAX: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - GGML_ASSERT(nb00 == ggml_type_size(src0->type)); - - const int64_t nrows = ggml_nrows(src0); - - int nth = 32; // SIMD width - while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - default: - { - GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); - } - } - - return n_fuse; -} - -static enum ggml_status ggml_metal_graph_compute( - ggml_backend_t backend, - struct ggml_cgraph * gf) { - struct ggml_backend_metal_context * ctx = backend->context; - struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; - - // number of nodes encoded by the main thread (empirically determined) - const int n_main = 128; - - // number of threads in addition to the main thread - const int n_cb = ctx->n_cb; - - // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them - // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread - // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes - // each thread creates it's own command buffer and enqueues the ops in parallel - // - // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2 - - @autoreleasepool { - ctx->gf = gf; - - ctx->n_nodes_0 = MIN(n_main, gf->n_nodes); - ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0; - - ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; - - const bool should_capture = ctx->capture_next_compute; - if (should_capture) { - ctx->capture_next_compute = false; - - if (!ctx->capture_started) { - // create capture scope - ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device]; - - MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; - descriptor.captureObject = ctx->capture_scope; - descriptor.destination = MTLCaptureDestinationGPUTraceDocument; - descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; - - NSError * error = nil; - if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { - GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); - } else { - [ctx->capture_scope beginScope]; - ctx->capture_started = true; - } - } - } - - // the main thread commits the first few commands immediately - // cmd_buf[n_cb] - { - id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; - ctx->cmd_bufs[n_cb].obj = cmd_buf; - - [cmd_buf enqueue]; - ctx->encode_async(n_cb); - } - - // prepare the rest of the command buffers asynchronously - // cmd_buf[0.. n_cb) - for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; - ctx->cmd_bufs[cb_idx].obj = cmd_buf; - - // always enqueue the first two command buffers - // enqueue all of the command buffers if we don't need to abort - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [cmd_buf enqueue]; - } - } - - dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async); - - // wait for completion and check status of each command buffer - // needed to detect if the device ran out-of-memory for example (#1881) - { - id cmd_buf = ctx->cmd_bufs[n_cb].obj; - [cmd_buf waitUntilCompleted]; - - MTLCommandBufferStatus status = [cmd_buf status]; - if (status != MTLCommandBufferStatusCompleted) { - GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); - if (status == MTLCommandBufferStatusError) { - GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); - } - - return GGML_STATUS_FAILED; - } - } - - for (int i = 0; i < n_cb; ++i) { - id cmd_buf = ctx->cmd_bufs[i].obj; - [cmd_buf waitUntilCompleted]; - - MTLCommandBufferStatus status = [cmd_buf status]; - if (status != MTLCommandBufferStatusCompleted) { - GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); - if (status == MTLCommandBufferStatusError) { - GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); - } - - return GGML_STATUS_FAILED; - } - - id next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil); - if (!next_buffer) { - continue; - } - - const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); - if (next_queued) { - continue; - } - - if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { - GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i); - return GGML_STATUS_ABORTED; - } - - [next_buffer commit]; - } - - if (!should_capture && ctx->capture_started) { - [ctx->capture_scope endScope]; - [[MTLCaptureManager sharedCaptureManager] stopCapture]; - } - } - - return GGML_STATUS_SUCCESS; -} - -//////////////////////////////////////////////////////////////////////////////// - -// backend interface - -static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) { - struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; - - for (int i = 0; i < ctx->n_buffers; i++) { - [ctx->buffers[i].metal release]; - } - - ggml_backend_metal_buffer_rset_free(ctx); - - if (ctx->owned) { -#if TARGET_OS_OSX - vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size); -#else - free(ctx->all_data); -#endif - } - - free(ctx); -} - -static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) { - struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; - - return ctx->all_data; -} - -static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { - memset((char *)tensor->data + offset, value, size); - - GGML_UNUSED(buffer); -} - -static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - memcpy((char *)tensor->data + offset, data, size); - - GGML_UNUSED(buffer); -} - -static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - memcpy(data, (const char *)tensor->data + offset, size); - - GGML_UNUSED(buffer); -} - -static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { - if (ggml_backend_buffer_is_host(src->buffer)) { - memcpy(dst->data, src->data, ggml_nbytes(src)); - return true; - } - return false; - - GGML_UNUSED(buffer); -} - -static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; - - memset(ctx->all_data, value, ctx->all_size); -} - -static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = { - /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_get_base, - /* .init_tensor = */ NULL, - /* .memset_tensor = */ ggml_backend_metal_buffer_memset_tensor, - /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor, - /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor, - /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor, - /* .clear = */ ggml_backend_metal_buffer_clear, - /* .reset = */ NULL, -}; - -// default buffer type - -static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) { - return "Metal"; - - GGML_UNUSED(buft); -} - -static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { -#ifndef GGML_METAL_NDEBUG -#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) - if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_LOG_DEBUG("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n", - __func__, - size_aligned / 1024.0 / 1024.0, - device.currentAllocatedSize / 1024.0 / 1024.0, - device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); - - if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) { - GGML_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); - } - } else { - GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", - __func__, - size_aligned / 1024.0 / 1024.0, - device.currentAllocatedSize / 1024.0 / 1024.0); - } -#endif -#endif - GGML_UNUSED(device); - GGML_UNUSED(size_aligned); -} - -static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); - - const size_t size_page = sysconf(_SC_PAGESIZE); - - size_t size_aligned = size; - if ((size_aligned % size_page) != 0) { - size_aligned += (size_page - (size_aligned % size_page)); - } - - struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context; - - GGML_ASSERT(ctx_dev->mtl_device != nil); - - id device = ctx_dev->mtl_device; - - ctx->all_data = ggml_metal_host_malloc(size_aligned); - ctx->all_size = size_aligned; - ctx->owned = true; - ctx->n_buffers = 1; - - if (ctx->all_data != NULL) { - ctx->buffers[0].data = ctx->all_data; - ctx->buffers[0].size = size; - ctx->buffers[0].metal = nil; - - if (size_aligned > 0) { - ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data - length:size_aligned - options:MTLResourceStorageModeShared - deallocator:nil]; - } - } - - if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) { - GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); - free(ctx); - return NULL; - } - - if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { - GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); - free(ctx); - return NULL; - } - - //ggml_backend_metal_log_allocated_size(device, size_aligned); - - return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); -} - -static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 32; - - GGML_UNUSED(buft); -} - -static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { - const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size; - - return max_size; -} - -static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) { - return true; - - GGML_UNUSED(buft); -} - -ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { - static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ ggml_backend_metal_buffer_type_is_host, - }, - /* .device = */ &g_ggml_backend_metal_device, - /* .context = */ NULL, - }; - - return &ggml_backend_buffer_type_metal; -} - -static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) { - return "Metal_Mapped"; - - GGML_UNUSED(buft); -} - -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) { - static struct ggml_backend_buffer_type ggml_backend_buffer_from_ptr_type_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_from_ptr_type_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ ggml_backend_metal_buffer_type_is_host, - }, - /* .device = */ &g_ggml_backend_metal_device, - /* .context = */ NULL, - }; - - return &ggml_backend_buffer_from_ptr_type_metal; -} - -// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr -ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) { - struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); - - ctx->all_data = data; - ctx->all_size = size; - ctx->owned = false; - ctx->n_buffers = 0; - - const size_t size_page = sysconf(_SC_PAGESIZE); - - // page-align the data ptr - { - const uintptr_t offs = (uintptr_t) data % size_page; - data = (void *) ((char *) data - offs); - size += offs; - } - - size_t size_aligned = size; - if ((size_aligned % size_page) != 0) { - size_aligned += (size_page - (size_aligned % size_page)); - } - - struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main; - - GGML_ASSERT(ctx_dev->mtl_device != nil); - - id device = ctx_dev->mtl_device; - - // the buffer fits into the max buffer size allowed by the device - if (size_aligned <= device.maxBufferLength) { - ctx->buffers[ctx->n_buffers].data = data; - ctx->buffers[ctx->n_buffers].size = size; - ctx->buffers[ctx->n_buffers].metal = nil; - - if (size_aligned > 0) { - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); - return false; - } - } - - ggml_backend_metal_log_allocated_size(device, size_aligned); - - ++ctx->n_buffers; - } else { - // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into - // one of the views - const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case - const size_t size_step = device.maxBufferLength - size_ovlp; - const size_t size_view = device.maxBufferLength; - - for (size_t i = 0; i < size; i += size_step) { - const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); - - ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i); - ctx->buffers[ctx->n_buffers].size = size_step_aligned; - ctx->buffers[ctx->n_buffers].metal = nil; - - if (size_step_aligned > 0) { - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); - return false; - } - } - - ggml_backend_metal_log_allocated_size(device, size_step_aligned); - - if (i + size_step < size) { - GGML_LOG_INFO("\n"); - } - - ++ctx->n_buffers; - } - } - - if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { - GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); - free(ctx); - return NULL; - } - - return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); -} - -// backend - -static const char * ggml_backend_metal_name(ggml_backend_t backend) { - return "Metal"; - - GGML_UNUSED(backend); -} - -static void ggml_backend_metal_free(ggml_backend_t backend) { - struct ggml_backend_metal_context * ctx = backend->context; - - ggml_metal_free(ctx); - - free(backend); -} - -static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - return ggml_metal_graph_compute(backend, cgraph); -} - -static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { - GGML_ASSERT(ggml_backend_is_metal(backend)); - - struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; - - if (ctx->n_cb != n_cb) { - ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS); - - if (ctx->n_cb > 2) { - GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb); - } - } - - if (ctx->encode_async) { - Block_release(ctx->encode_async); - } - - ctx->encode_async = Block_copy(^(size_t iter) { - const int cb_idx = iter; - const int n_cb_l = ctx->n_cb; - - const int n_nodes_0 = ctx->n_nodes_0; - const int n_nodes_1 = ctx->n_nodes_1; - - const int n_nodes_per_cb = ctx->n_nodes_per_cb; - - id cmd_buf = ctx->cmd_bufs[cb_idx].obj; - - id encoder = [cmd_buf computeCommandEncoder]; - - int node_start = 0; - int node_end = n_nodes_0; - - if (cb_idx < n_cb_l) { - node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb); - node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1)); - } - - const bool should_capture = ctx->capture_next_compute; - - struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool; - ggml_metal_mem_pool_reset(mem_pool); - - for (int idx = node_start; idx < node_end;) { - if (should_capture) { - [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; - } - - const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool); - if (idx + res > node_end) { - GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s", - "https://github.com/ggml-org/llama.cpp/pull/14849"); - } - - if (should_capture) { - [encoder popDebugGroup]; - } - - if (res == 0) { - break; - } - - idx += res; - } - - [encoder endEncoding]; - - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [cmd_buf commit]; - } - }); -} - -static struct ggml_backend_i ggml_backend_metal_i = { - /* .get_name = */ ggml_backend_metal_name, - /* .free = */ ggml_backend_metal_free, - /* .set_tensor_async = */ NULL, - /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_metal_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, -}; - -static ggml_guid_t ggml_backend_metal_guid(void) { - static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 }; - return &guid; -} - -// TODO: remove in the future -ggml_backend_t ggml_backend_metal_init(void) { - ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0); - - struct ggml_backend_metal_context * ctx = ggml_metal_init(dev); - if (ctx == NULL) { - GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); - return NULL; - } - - ggml_backend_t backend = malloc(sizeof(struct ggml_backend)); - - *backend = (struct ggml_backend) { - /* .guid = */ ggml_backend_metal_guid(), - /* .interface = */ ggml_backend_metal_i, - /* .device = */ dev, - /* .context = */ ctx, - }; - - ggml_backend_metal_set_n_cb(backend, 1); - - return backend; -} - -bool ggml_backend_is_metal(ggml_backend_t backend) { - return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid()); -} - -void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) { - GGML_ASSERT(ggml_backend_is_metal(backend)); - - struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; - - ctx->abort_callback = abort_callback; - ctx->abort_callback_data = user_data; -} - -bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) { - GGML_ASSERT(ggml_backend_is_metal(backend)); - - struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; - - GGML_ASSERT(ctx_dev->mtl_device != nil); - - return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; -} - -void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { - GGML_ASSERT(ggml_backend_is_metal(backend)); - - struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; - ctx->capture_next_compute = true; -} - -// backend device - -static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) { - return "Metal"; - - GGML_UNUSED(dev); -} - -static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) { - struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; - - return ctx_dev->name; -} - -static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - if (@available(macOS 10.12, iOS 16.0, *)) { - struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; - id device = ctx_dev->mtl_device; - - *total = device.recommendedMaxWorkingSetSize; - *free = *total - device.currentAllocatedSize; - } else { - *free = 1; - *total = 1; - } -} - -static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) { - return GGML_BACKEND_DEVICE_TYPE_GPU; - - GGML_UNUSED(dev); -} - -static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { - props->name = ggml_backend_metal_device_get_name(dev); - props->description = ggml_backend_metal_device_get_description(dev); - props->type = ggml_backend_metal_device_get_type(dev); - ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); - props->caps = (struct ggml_backend_dev_caps) { - /* .async = */ false, - /* .host_buffer = */ false, - /* .buffer_from_host_ptr = */ true, - /* .events = */ false, - }; -} - -static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) { - struct ggml_backend_metal_context * ctx = ggml_metal_init(dev); - if (ctx == NULL) { - GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); - return NULL; - } - - ggml_backend_t backend = malloc(sizeof(struct ggml_backend)); - - *backend = (struct ggml_backend) { - /* .guid = */ ggml_backend_metal_guid(), - /* .interface = */ ggml_backend_metal_i, - /* .device = */ dev, - /* .context = */ ctx, - }; - - ggml_backend_metal_set_n_cb(backend, 1); - - return backend; - - GGML_UNUSED(params); -} - -static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) { - return ggml_backend_metal_buffer_type(); - - GGML_UNUSED(dev); -} - -static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { - struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); - - ctx->all_data = ptr; - ctx->all_size = size; - ctx->owned = false; - ctx->n_buffers = 0; - - const size_t size_page = sysconf(_SC_PAGESIZE); - - // page-align the data ptr - { - const uintptr_t offs = (uintptr_t) ptr % size_page; - ptr = (void *) ((char *) ptr - offs); - size += offs; - } - - size_t size_aligned = size; - if ((size_aligned % size_page) != 0) { - size_aligned += (size_page - (size_aligned % size_page)); - } - - struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; - - GGML_ASSERT(ctx_dev->mtl_device != nil); - - id device = ctx_dev->mtl_device; - - // the buffer fits into the max buffer size allowed by the device - if (size_aligned <= device.maxBufferLength) { - ctx->buffers[ctx->n_buffers].data = ptr; - ctx->buffers[ctx->n_buffers].size = size; - ctx->buffers[ctx->n_buffers].metal = nil; - - if (size_aligned > 0) { - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); - return false; - } - } - - ggml_backend_metal_log_allocated_size(device, size_aligned); - - ++ctx->n_buffers; - } else { - // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into - // one of the views - const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case - const size_t size_step = device.maxBufferLength - size_ovlp; - const size_t size_view = device.maxBufferLength; - - for (size_t i = 0; i < size; i += size_step) { - const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); - - ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) ptr + i); - ctx->buffers[ctx->n_buffers].size = size_step_aligned; - ctx->buffers[ctx->n_buffers].metal = nil; - - if (size_step_aligned > 0) { - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); - return false; - } - } - - ggml_backend_metal_log_allocated_size(device, size_step_aligned); - - if (i + size_step < size) { - GGML_LOG_INFO("\n"); - } - - ++ctx->n_buffers; - } - } - - if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { - GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); - free(ctx); - return NULL; - } - - return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); -} - -static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { - struct ggml_backend_metal_device_context * ctx_dev = dev->context; - - return ggml_metal_supports_op(ctx_dev, op); -} - -static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { - return - buft->iface.get_name == ggml_backend_metal_buffer_type_get_name || - buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name; - - GGML_UNUSED(dev); -} - -static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { - return false; - - GGML_UNUSED(dev); - GGML_UNUSED(op); -} - -static struct ggml_backend_device_i ggml_backend_metal_device_i = { - /* .get_name = */ ggml_backend_metal_device_get_name, - /* .get_description = */ ggml_backend_metal_device_get_description, - /* .get_memory = */ ggml_backend_metal_device_get_memory, - /* .get_type = */ ggml_backend_metal_device_get_type, - /* .get_props = */ ggml_backend_metal_device_get_props, - /* .init_backend = */ ggml_backend_metal_device_init, - /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type, - /* .get_host_buffer_type = */ NULL, - /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr, - /* .supports_op = */ ggml_backend_metal_device_supports_op, - /* .supports_buft = */ ggml_backend_metal_device_supports_buft, - /* .offload_op = */ ggml_backend_metal_device_offload_op, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, -}; - -// backend registry - -static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) { - return "Metal"; - - GGML_UNUSED(reg); -} - -static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) { - return 1; - - GGML_UNUSED(reg); -} - -static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) { - GGML_ASSERT(index == 0); - - return &g_ggml_backend_metal_device; - - GGML_UNUSED(reg); - GGML_UNUSED(index); -} - -static struct ggml_backend_feature g_ggml_backend_metal_features[] = { -#if defined(GGML_METAL_EMBED_LIBRARY) - { "EMBED_LIBRARY", "1" }, -#endif -#if defined(GGML_METAL_USE_BF16) - { "BF16", "1" }, -#endif - { nil, nil }, -}; - -static struct ggml_backend_feature * ggml_backend_metal_get_features(ggml_backend_reg_t reg) { - return g_ggml_backend_metal_features; - - GGML_UNUSED(reg); -} - -static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) { - if (strcmp(name, "ggml_backend_get_features") == 0) { - return (void *)ggml_backend_metal_get_features; - } - - return NULL; - - GGML_UNUSED(reg); -} -static struct ggml_backend_reg_i ggml_backend_metal_reg_i = { - /* .get_name = */ ggml_backend_metal_reg_get_name, - /* .device_count = */ ggml_backend_metal_reg_device_count, - /* .device_get = */ ggml_backend_metal_reg_device_get, - /* .get_proc_address = */ ggml_backend_metal_get_proc_address, -}; - -// called upon program exit -static void ggml_metal_cleanup(void) { - ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main); -} - -// TODO: make thread-safe -ggml_backend_reg_t ggml_backend_metal_reg(void) { - ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main); - - // register cleanup callback - // TODO: not ideal, but not sure if there is a better way to do this in Objective-C - atexit(ggml_metal_cleanup); - - { - g_ggml_backend_metal_reg = (struct ggml_backend_reg) { - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_metal_reg_i, - /* .context = */ NULL, - }; - - g_ggml_backend_metal_device = (struct ggml_backend_device) { - /* .iface = */ ggml_backend_metal_device_i, - /* .reg = */ &g_ggml_backend_metal_reg, - /* .context = */ &g_ggml_ctx_dev_main, - }; - } - - return &g_ggml_backend_metal_reg; -} - -GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index b35a3bbdc31..05ca9842d86 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -15,6 +15,10 @@ using namespace metal; #define MIN(x, y) ((x) < (y) ? (x) : (y)) #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } +#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1)) + +#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x) + #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf @@ -23,11 +27,11 @@ using namespace metal; // .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/ggml-metal.metal // .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal // -#if __METAL_VERSION__ < 310 && defined(GGML_METAL_USE_BF16) -#undef GGML_METAL_USE_BF16 +#if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16) +#undef GGML_METAL_HAS_BF16 #endif -#if defined(GGML_METAL_USE_BF16) +#if defined(GGML_METAL_HAS_BF16) typedef matrix bfloat4x4; #endif @@ -68,6 +72,11 @@ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) reg = (type4x4)(*src); } +template +void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) { + reg = (type4)(*src); +} + template void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { reg = (type4x4)(*src); @@ -78,7 +87,7 @@ void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) { reg = (type4)(*(src)); } -#if defined(GGML_METAL_USE_BF16) +#if defined(GGML_METAL_HAS_BF16) template void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { reg = (type4x4)(*src); @@ -919,7 +928,7 @@ kernel void kernel_add_fuse_impl( typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; -template [[host_name("kernel_add")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; +template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; @@ -928,7 +937,7 @@ template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_ template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; -kernel void kernel_sub( +kernel void kernel_sub_fuse_1( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -954,7 +963,7 @@ kernel void kernel_sub( } } -kernel void kernel_mul( +kernel void kernel_mul_fuse_1( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -974,13 +983,20 @@ kernel void kernel_mul( device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + if (args.ne10 == 1) { + const float x = *((device float *)(src1_ptr)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; + } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + } } } -kernel void kernel_div( +kernel void kernel_div_fuse_1( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -1000,9 +1016,16 @@ kernel void kernel_div( device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); + if (args.ne10 == 1) { + const float x = 1.0f / *((device float *)(src1_ptr)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; + } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); + } } } @@ -1073,23 +1096,17 @@ kernel void kernel_add_row_c4_fuse_impl( device const char * src1, device char * dst, uint tpig[[thread_position_in_grid]]) { - const uint nb = args.ne00/4; const uint i = tpig % nb; device const float4 * src0_row = (device const float4 *) (src0); device float4 * dst_row = (device float4 *) (dst); - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - float4 res = src0_row[tpig]; #pragma unroll(F) for (short j = 0; j < F; ++j) { - res += src1_row[j][i]; + res += ((device const float4 *) (src1 + args.o1[j]))[i]; } dst_row[tpig] = res; @@ -1097,7 +1114,7 @@ kernel void kernel_add_row_c4_fuse_impl( typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t; -template [[host_name("kernel_add_row_c4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; +template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>; template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>; template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>; @@ -1137,7 +1154,7 @@ kernel void kernel_sub_row_c4_fuse_impl( typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t; -template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; +template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; template kernel void kernel_mul_row_c4_fuse_impl( @@ -1170,7 +1187,7 @@ kernel void kernel_mul_row_c4_fuse_impl( typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t; -template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; +template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; template kernel void kernel_div_row_c4_fuse_impl( @@ -1203,55 +1220,80 @@ kernel void kernel_div_row_c4_fuse_impl( typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t; -template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; +template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; -kernel void kernel_scale( +kernel void kernel_scale_f32( + constant ggml_metal_kargs_scale & args, device const float * src0, device float * dst, - constant float & scale, - constant float & bias, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale + bias; + dst[tpig] = src0[tpig] * args.scale + args.bias; } -kernel void kernel_scale_4( +kernel void kernel_scale_f32_4( + constant ggml_metal_kargs_scale & args, device const float4 * src0, device float4 * dst, - constant float & scale, - constant float & bias, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale + bias; + dst[tpig] = src0[tpig] * args.scale + args.bias; } -kernel void kernel_clamp( +kernel void kernel_clamp_f32( + constant ggml_metal_kargs_clamp & args, device const float * src0, device float * dst, - constant float & min, - constant float & max, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); + dst[tpig] = clamp(src0[tpig], args.min, args.max); +} + +kernel void kernel_clamp_f32_4( + constant ggml_metal_kargs_clamp & args, + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = clamp(src0[tpig], args.min, args.max); } -kernel void kernel_relu( +kernel void kernel_relu_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = max(0.0f, src0[tpig]); } -kernel void kernel_sigmoid( +kernel void kernel_relu_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]); +} + +kernel void kernel_sigmoid_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); } -kernel void kernel_tanh( +kernel void kernel_sigmoid_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); +} + +kernel void kernel_tanh_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = precise::tanh(x); + dst[tpig] = precise::tanh(src0[tpig]); +} + +kernel void kernel_tanh_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = precise::tanh(src0[tpig]); } constant float GELU_COEF_A = 0.044715f; @@ -1259,7 +1301,7 @@ constant float GELU_QUICK_COEF = -1.702f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; constant float SQRT_2_INV = 0.70710678118654752440084436210484f; -kernel void kernel_gelu( +kernel void kernel_gelu_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { @@ -1268,7 +1310,7 @@ kernel void kernel_gelu( dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } -kernel void kernel_gelu_4( +kernel void kernel_gelu_f32_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -1281,7 +1323,7 @@ kernel void kernel_gelu_4( dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } -kernel void kernel_gelu_quick( +kernel void kernel_gelu_quick_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { @@ -1290,7 +1332,7 @@ kernel void kernel_gelu_quick( dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); } -kernel void kernel_gelu_quick_4( +kernel void kernel_gelu_quick_f32_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -1317,7 +1359,7 @@ T erf_approx(T x) { return sign_x * y; } -kernel void kernel_gelu_erf( +kernel void kernel_gelu_erf_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { @@ -1326,7 +1368,7 @@ kernel void kernel_gelu_erf( dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); } -kernel void kernel_gelu_erf_4( +kernel void kernel_gelu_erf_f32_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -1335,7 +1377,7 @@ kernel void kernel_gelu_erf_4( dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); } -kernel void kernel_silu( +kernel void kernel_silu_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { @@ -1343,7 +1385,7 @@ kernel void kernel_silu( dst[tpig] = x / (1.0f + exp(-x)); } -kernel void kernel_silu_4( +kernel void kernel_silu_f32_4( device const float4 * src0, device float4 * dst, uint tpig[[thread_position_in_grid]]) { @@ -1351,99 +1393,202 @@ kernel void kernel_silu_4( dst[tpig] = x / (1.0f + exp(-x)); } -kernel void kernel_elu( +kernel void kernel_elu_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; + const float x = src0[tpig]; dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f); } -kernel void kernel_sqr( +kernel void kernel_elu_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const float4 x = src0[tpig]; + dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f); + dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f); + dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f); + dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f); +} + +kernel void kernel_sqr_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * src0[tpig]; } -kernel void kernel_sqrt( +kernel void kernel_sqr_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src0[tpig]; +} + +kernel void kernel_sqrt_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = sqrt(src0[tpig]); } -kernel void kernel_sin( +kernel void kernel_sqrt_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sqrt(src0[tpig]); +} + +kernel void kernel_sin_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = sin(src0[tpig]); } -kernel void kernel_cos( +kernel void kernel_sin_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sin(src0[tpig]); +} + +kernel void kernel_cos_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = cos(src0[tpig]); } -kernel void kernel_neg( +kernel void kernel_cos_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = cos(src0[tpig]); +} + +kernel void kernel_log_f32( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = log(src0[tpig]); +} + +kernel void kernel_log_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = log(src0[tpig]); +} + +kernel void kernel_neg_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = -src0[tpig]; } -kernel void kernel_abs( +kernel void kernel_neg_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = -src0[tpig]; +} + +kernel void kernel_abs_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = fabs(src0[tpig]); } -kernel void kernel_sgn( +kernel void kernel_abs_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = fabs(src0[tpig]); +} + +kernel void kernel_sgn_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f); + dst[tpig] = sign(src0[tpig]); +} + +kernel void kernel_sgn_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sign(src0[tpig]); } -kernel void kernel_step( +kernel void kernel_step_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f; + dst[tpig] = step(0.0f, src0[tpig]); +} + +kernel void kernel_step_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = step(0.0f, src0[tpig]); } -kernel void kernel_hardswish( +kernel void kernel_hardswish_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; + const float x = src0[tpig]; dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); } -kernel void kernel_hardsigmoid( +kernel void kernel_hardswish_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const float4 x = src0[tpig]; + dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); +} + +kernel void kernel_hardsigmoid_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; + const float x = src0[tpig]; + dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); +} + +kernel void kernel_hardsigmoid_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const float4 x = src0[tpig]; dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); } -kernel void kernel_exp( +kernel void kernel_exp_f32( device const float * src0, device float * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = exp(src0[tpig]); } -kernel void kernel_reglu( +kernel void kernel_exp_f32_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = exp(src0[tpig]); +} + +kernel void kernel_reglu_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -1459,11 +1604,11 @@ kernel void kernel_reglu( } } -kernel void kernel_geglu( +kernel void kernel_geglu_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -1481,11 +1626,11 @@ kernel void kernel_geglu( } } -kernel void kernel_swiglu( +kernel void kernel_swiglu_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -1503,11 +1648,11 @@ kernel void kernel_swiglu( } } -kernel void kernel_swiglu_oai( +kernel void kernel_swiglu_oai_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -1529,11 +1674,11 @@ kernel void kernel_swiglu_oai( } } -kernel void kernel_geglu_erf( +kernel void kernel_geglu_erf_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -1551,11 +1696,11 @@ kernel void kernel_geglu_erf( } } -kernel void kernel_geglu_quick( +kernel void kernel_geglu_quick_f32( + constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_glu & args, uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { @@ -1625,16 +1770,16 @@ kernel void kernel_sum_rows( typedef decltype(kernel_sum_rows) kernel_sum_rows_t; -template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows; -template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows; +template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; +template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; template kernel void kernel_soft_max( + constant ggml_metal_kargs_soft_max & args, device const char * src0, device const char * src1, device const char * src2, device char * dst, - constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -1736,11 +1881,11 @@ kernel void kernel_soft_max( template kernel void kernel_soft_max_4( + constant ggml_metal_kargs_soft_max & args, device const char * src0, device const char * src1, device const char * src2, device char * dst, - constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -1850,53 +1995,12 @@ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kerne template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; -kernel void kernel_diag_mask_inf( - device const float * src0, - device float * dst, - constant ggml_metal_kargs_diag_mask_inf & args, - uint3 tpig[[thread_position_in_grid]]) { - const int64_t i02 = tpig[2]; - const int64_t i01 = tpig[1]; - const int64_t i00 = tpig[0]; - - if (i00 > args.n_past + i01) { - dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY; - } else { - dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00]; - } -} - -kernel void kernel_diag_mask_inf_8( - device const float4 * src0, - device float4 * dst, - constant ggml_metal_kargs_diag_mask_inf & args, - uint3 tpig[[thread_position_in_grid]]) { - - const int64_t i = 2*tpig[0]; - - dst[i+0] = src0[i+0]; - dst[i+1] = src0[i+1]; - int64_t i4 = 4*i; - const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01; - const int64_t i01 = i4/(args.ne00); i4 -= i01*args.ne00; - const int64_t i00 = i4; - for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= args.n_past + i01) { - break; - } - dst[i+1][k] = -INFINITY; - if (i00 + k > args.n_past + i01) { - dst[i][k] = -INFINITY; - } - } -} - // ref: ggml.c:ggml_compute_forward_ssm_conv_f32 -kernel void kernel_ssm_conv_f32( +kernel void kernel_ssm_conv_f32_f32( + constant ggml_metal_kargs_ssm_conv & args, device const void * src0, device const void * src1, device float * dst, - constant ggml_metal_kargs_ssm_conv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -1925,6 +2029,7 @@ kernel void kernel_ssm_conv_f32( // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part kernel void kernel_ssm_scan_f32( + constant ggml_metal_kargs_ssm_scan & args, device const void * src0, device const void * src1, device const void * src2, @@ -1934,7 +2039,6 @@ kernel void kernel_ssm_scan_f32( device const void * src6, device float * dst, threadgroup float * shared [[threadgroup(0)]], - constant ggml_metal_kargs_ssm_scan & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]], @@ -1964,14 +2068,15 @@ kernel void kernel_ssm_scan_f32( device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); const int64_t i = i0 + i1*nc; + const int64_t g = ir / (nh / ng); // repeat_interleave float s0 = s0_buff[i]; float s = s_buff[i]; device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); + device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43); + device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53); device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); for (int64_t i2 = 0; i2 < n_t; ++i2) { @@ -2039,7 +2144,8 @@ kernel void kernel_ssm_scan_f32( } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part -kernel void kernel_ssm_scan_f32_group( +kernel void kernel_ssm_scan_group_f32( + constant ggml_metal_kargs_ssm_scan & args, device const void * src0, device const void * src1, device const void * src2, @@ -2049,7 +2155,6 @@ kernel void kernel_ssm_scan_f32_group( device const void * src6, device float * dst, threadgroup float * shared [[threadgroup(0)]], - constant ggml_metal_kargs_ssm_scan & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]], @@ -2079,14 +2184,15 @@ kernel void kernel_ssm_scan_f32_group( device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); const int64_t i = i0 + i1*nc; + const int64_t g = ir / (nh / ng); // repeat_interleave float s0 = s0_buff[i]; float s = s_buff[i]; device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); + device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43); + device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53); device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); for (int64_t i2 = 0; i2 < n_t; ++i2) { @@ -2327,24 +2433,22 @@ kernel void kernel_rwkv_wkv7_f32( } } -kernel void kernel_argmax( - device const void * x, - device int32_t * dst, - constant int64_t & ncols, - constant uint64_t & nb01, - threadgroup float * shared_maxval [[threadgroup(0)]], - threadgroup int32_t * shared_argmax [[threadgroup(1)]], +kernel void kernel_argmax_f32( + constant ggml_metal_kargs_argmax & args, + device const char * src0, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01); + device const float * x_row = (device const float *) ((device const char *) src0 + tgpig * args.nb01); float lmax = -INFINITY; int32_t larg = -1; - for (int i00 = tpitg; i00 < ncols; i00 += ntg) { + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { if (x_row[i00] > lmax) { lmax = x_row[i00]; larg = i00; @@ -2355,6 +2459,11 @@ kernel void kernel_argmax( float max_val = simd_max(lmax); int32_t arg_val = simd_max(select(-1, larg, lmax == max_val)); + device int32_t * dst_i32 = (device int32_t *) dst; + + threadgroup float * shared_maxval = (threadgroup float *) shmem; + threadgroup int32_t * shared_argmax = (threadgroup int32_t *) shmem + N_SIMDWIDTH; + if (ntg > N_SIMDWIDTH) { if (sgitg == 0) { shared_maxval[tiisg] = -INFINITY; @@ -2376,35 +2485,48 @@ kernel void kernel_argmax( float max_val_reduced = simd_max(max_val); int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced)); - dst[tgpig] = arg_val_reduced; + dst_i32[tgpig] = arg_val_reduced; return; } - dst[tgpig] = arg_val; + dst_i32[tgpig] = arg_val; } -kernel void kernel_norm( +// F == 1 : rms_norm (no fuse) +// F == 2 : rms_norm + mul +// F == 3 : rms_norm + mul + add +template +kernel void kernel_norm_fuse_impl( constant ggml_metal_kargs_norm & args, device const char * src0, + device const char * src1_0, + device const char * src1_1, device char * dst, threadgroup float * shmem_f32 [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - ushort tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { if (sgitg == 0) { shmem_f32[tiisg] = 0.0f; } - device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + const int i01 = tgpig.x; + const int i02 = tgpig.y; + const int i03 = tgpig.z; + + device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]); + + device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]); + device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]); float4 sumf4(0.0f); float sumf = 0.0f; - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { sumf4 += x[i00]; } sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3]; @@ -2423,10 +2545,10 @@ kernel void kernel_norm( const float mean = sumf/args.ne00; - device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); sumf = 0.0f; - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { y[i00] = x[i00] - mean; sumf += dot(y[i00], y[i00]); } @@ -2446,11 +2568,25 @@ kernel void kernel_norm( const float variance = sumf/args.ne00; const float scale = 1.0f/sqrt(variance + args.eps); - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { - y[i00] = y[i00] * scale; + for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) { + if (F == 1) { + y[i00] = (y[i00]*scale); + } + if (F == 2) { + y[i00] = (y[i00]*scale)*f0[i00]; + } + if (F == 3) { + y[i00] = (y[i00]*scale)*f0[i00] + f1[i00]; + } } } +typedef decltype(kernel_norm_fuse_impl<1>) kernel_norm_fuse_t; + +template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<1>; +template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<2>; +template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<3>; + // F == 1 : rms_norm (no fuse) // F == 2 : rms_norm + mul // F == 3 : rms_norm + mul + add @@ -2518,11 +2654,11 @@ kernel void kernel_rms_norm_fuse_impl( typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t; -template [[host_name("kernel_rms_norm")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>; -template [[host_name("kernel_rms_norm_mul")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>; -template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>; +template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>; +template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>; +template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>; -kernel void kernel_l2_norm( +kernel void kernel_l2_norm_f32( constant ggml_metal_kargs_l2_norm & args, device const char * src0, device char * dst, @@ -2565,10 +2701,10 @@ kernel void kernel_l2_norm( } } -kernel void kernel_group_norm( +kernel void kernel_group_norm_f32( + constant ggml_metal_kargs_group_norm & args, device const float * src0, device float * dst, - constant ggml_metal_kargs_group_norm & args, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], @@ -2576,7 +2712,7 @@ kernel void kernel_group_norm( uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { const int64_t ne = args.ne00*args.ne01*args.ne02; - const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups); + const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp); int start = tgpig * gs; int end = start + gs; @@ -2734,7 +2870,52 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } -template +template +static inline void helper_mv_reduce_and_write( + device float * dst_f32, + float sumf[NR0], + const int r0, + const int ne01, + ushort tiisg, + ushort sgitg, + threadgroup char * shmem) { + constexpr short NW = N_SIMDWIDTH; + + threadgroup float * shmem_f32[NR0]; + + for (short row = 0; row < NR0; ++row) { + shmem_f32[row] = (threadgroup float *) shmem + NW*row; + + if (sgitg == 0) { + shmem_f32[row][tiisg] = 0.0f; + } + + sumf[row] = simd_sum(sumf[row]); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short row = 0; row < NR0; ++row) { + if (tiisg == 0) { + shmem_f32[row][sgitg] = sumf[row]; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short row = 0; row < NR0 && r0 + row < ne01; ++row) { + float tot = simd_sum(shmem_f32[row][tiisg]); + + if (tiisg == 0 && sgitg == 0) { + dst_f32[r0 + row] = tot; + } + } +} + +constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]]; +constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]]; + +template void mul_vec_q_n_f32_impl( args_t args, device const char * src0, @@ -2744,45 +2925,54 @@ void mul_vec_q_n_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { - const int nb = args.ne00/QK4_0; + const short NSG = FC_mul_mv_nsg; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; + constexpr short NW = N_SIMDWIDTH; + constexpr short NQ = 16; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int nb = args.ne00/QK4_0; + + const int r0 = (tgpig.x*NSG + sgitg)*NR0; + //const int r0 = tgpig.x*NR0; + const int r1 = tgpig.y; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q_type * ax[nr0]; - for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + device const block_q_type * ax[NR0]; + FOR_UNROLL (int row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } - float yl[16]; // src1 vector cache - float sumf[nr0] = {0.f}; + float sumf[NR0] = {0.f}; + + const short ix = (tiisg/(NW/NQ)); + const short il = (tiisg%(NW/NQ))*8; + + //const int ib0 = sgitg*NQ + ix; + const int ib0 = ix; - const short ix = (tiisg/2); - const short il = (tiisg%2)*8; + float yl[16]; // src1 vector cache - device const float * yb = y + ix*QK4_0 + il; + //device const float * yb = y + ix*QK4_0 + il; + device const float * yb = y + ib0*QK4_0 + il; // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { + //for (int ib = ib0; ib < nb; ib += NSG*NQ) { + for (int ib = ib0; ib < nb; ib += NQ) { float sumy[2] = { 0.f, 0.f }; -#pragma unroll - for (short i = 0; i < 8; i += 2) { + FOR_UNROLL (short i = 0; i < 8; i += 2) { sumy[0] += yb[i + 0] + yb[i + 1]; yl[i + 0] = yb[i + 0]; yl[i + 1] = yb[i + 1]/256.f; @@ -2792,21 +2982,23 @@ void mul_vec_q_n_f32_impl( yl[i + 9] = yb[i + 17]/4096.f; } -#pragma unroll - for (short row = 0; row < nr0; row++) { + FOR_UNROLL (short row = 0; row < NR0; row++) { sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); } yb += QK4_0 * 16; + //yb += NSG*NQ*QK4_0; } device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; - for (int row = 0; row < nr0; ++row) { + //helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); + + for (int row = 0; row < NR0; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < args.ne01) { - dst_f32[first_row + row] = tot; + if (tiisg == 0 && r0 + row < args.ne01) { + dst_f32[r0 + row] = tot; } } } @@ -2816,10 +3008,11 @@ kernel void kernel_mul_mv_q4_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -2827,10 +3020,11 @@ kernel void kernel_mul_mv_q4_1_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -2838,10 +3032,11 @@ kernel void kernel_mul_mv_q5_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -2849,15 +3044,14 @@ kernel void kernel_mul_mv_q5_1_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -#define NB_Q8_0 8 - -template +template void kernel_mul_mv_q8_0_f32_impl( args_t args, device const char * src0, @@ -2867,66 +3061,68 @@ void kernel_mul_mv_q8_0_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + constexpr short NW = N_SIMDWIDTH; + constexpr short NQ = 8; + const int nb = args.ne00/QK8_0; - const int r0 = tgpig.x; + const int r0 = tgpig.x*NR0; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; - const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q8_0 * ax[nr0]; - for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + device const block_q8_0 * ax[NR0]; + FOR_UNROLL (short row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } - float yl[NB_Q8_0]; - float sumf[nr0] = { 0.f }; + float sumf[NR0] = { 0.f }; - const short ix = tiisg/4; - const short il = tiisg%4; + const short ix = tiisg/(NW/NQ); + const short il = tiisg%(NW/NQ); - device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; + const int ib0 = sgitg*NQ + ix; - // each thread in a SIMD group deals with NB_Q8_0 quants at a time - for (int ib = ix; ib < nb; ib += nw/4) { - for (short i = 0; i < NB_Q8_0; ++i) { + float yl[NQ]; + + device const float * yb = y + ib0*QK8_0 + il*NQ; + + // each thread in a SIMD group deals with NQ quants at a time + for (int ib = ib0; ib < nb; ib += NSG*NQ) { + for (short i = 0; i < NQ; ++i) { yl[i] = yb[i]; } - for (short row = 0; row < nr0; row++) { - device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; + for (short row = 0; row < NR0; row++) { + device const int8_t * qs = ax[row][ib].qs + il*NQ; + float sumq = 0.f; - for (short iq = 0; iq < NB_Q8_0; ++iq) { - sumq += qs[iq] * yl[iq]; + FOR_UNROLL (short i = 0; i < NQ; ++i) { + sumq += qs[i] * yl[i]; } + sumf[row] += sumq*ax[row][ib].d; } - yb += nw*NB_Q8_0; + yb += NSG*NQ*QK8_0; } device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0; ++row) { - const float tot = simd_sum(sumf[row]); - - if (tiisg == 0 && first_row + row < args.ne01) { - dst_f32[first_row + row] = tot; - } - } + helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } [[host_name("kernel_mul_mv_q8_0_f32")]] @@ -2935,15 +3131,16 @@ kernel void kernel_mul_mv_q8_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } // mat-vec kernel processing in chunks of float4 // chpb - chunks per quantization block -template +template void kernel_mul_mv_ext_q4_f32_impl( constant ggml_metal_kargs_mul_mv_ext & args, device const char * src0, @@ -2952,6 +3149,9 @@ void kernel_mul_mv_ext_q4_f32_impl( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short NSG = FC_mul_mv_nsg; + const short nxpsg = FC_mul_mv_nxpsg; + const short chpt = 4; // chunks per thread //const short nxpsg = (32); @@ -2960,7 +3160,7 @@ void kernel_mul_mv_ext_q4_f32_impl( const short tx = tiisg%nxpsg; const short ty = tiisg/nxpsg; - const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty; const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; @@ -3001,7 +3201,6 @@ void kernel_mul_mv_ext_q4_f32_impl( #pragma unroll(r1ptg) for (short ir1 = 0; ir1 < r1ptg; ++ir1) { sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]); - } } @@ -3044,7 +3243,7 @@ void kernel_mul_mv_ext_q4_f32_impl( } // mat-vec kernel processing in chunks of float4x4 -template +template void kernel_mul_mv_ext_q4x4_f32_impl( constant ggml_metal_kargs_mul_mv_ext & args, device const char * src0, @@ -3053,6 +3252,9 @@ void kernel_mul_mv_ext_q4x4_f32_impl( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short NSG = FC_mul_mv_nsg; + const short nxpsg = FC_mul_mv_nxpsg; + const short chpt = 1; //const short nxpsg = (32); @@ -3061,7 +3263,7 @@ void kernel_mul_mv_ext_q4x4_f32_impl( const short tx = tiisg%nxpsg; const short ty = tiisg/nxpsg; - const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty; const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; @@ -3158,12 +3360,7 @@ kernel void kernel_mul_mv_ext_q4_f32_disp( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - switch (args.nxpsg) { - case 4: kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 8: kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - } + kernel_mul_mv_ext_q4_f32_impl(args, src0, src1, dst, tgpig, tiisg, sgitg); } template @@ -3175,17 +3372,17 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - switch (args.nxpsg) { - case 4: kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 8: kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; - } + kernel_mul_mv_ext_q4x4_f32_impl(args, src0, src1, dst, tgpig, tiisg, sgitg); } typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t; typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t; +template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>; +template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>; +template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>; +template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>; + template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>; template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>; template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>; @@ -3241,104 +3438,217 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>; template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>; -#define N_MV_T_T 4 - -template -void kernel_mul_mv_impl( +template +void kernel_mul_mv_t_t_impl( args_t args, device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem, uint3 tgpig, - ushort tiisg) { - const int r0 = tgpig.x; - const int rb = tgpig.y*N_MV_T_T; + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + constexpr short NW = N_SIMDWIDTH; + constexpr short NB = 32; + constexpr short NF = 8; + + const int nb = args.ne00/NB; + + const int r0 = tgpig.x*NR0; + const int r1 = tgpig.y; const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const T0 * x = (device const T0 *) (src0 + offset0); + //device const T0 * x = (device const T0 *) (src0 + offset0); + device const T1 * y = (device const T1 *) (src1 + offset1); - device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + // pointers to src0 rows + device const T0 * ax [NR0]; + FOR_UNROLL (short row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - if (args.ne00 < 128) { - for (int row = 0; row < N_MV_T_T; ++row) { - int r1 = rb + row; - if (r1 >= args.ne11) { - break; - } + ax[row] = (device const T0 *) ((device char *) src0 + offset0); + } - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + float sumf[NR0] = { 0.f }; - device const T1 * y = (device const T1 *) (src1 + offset1); + const short ix = tiisg/(NW/NF); + const short il = tiisg%(NW/NF); - float sumf = 0; - for (int i = tiisg; i < args.ne00; i += 32) { - sumf += (T0) x[i] * (T1) y[i]; - } + const int ib0 = sgitg*NF + ix; - float sum_all = simd_sum(sumf); - if (tiisg == 0) { - dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; - } + T1 yl[NF]; + + device const T1 * yb = y + (ib0*NB + il*NF); + + for (int ib = ib0; ib < nb; ib += NSG*NF) { + for (short i = 0; i < NF; ++i) { + yl[i] = yb[i]; } - } else { - device const T04 * x4 = (device const T04 *) x; - for (int row = 0; row < N_MV_T_T; ++row) { - int r1 = rb + row; - if (r1 >= args.ne11) { - break; + + for (short row = 0; row < NR0; row++) { + device const T0 * xb = ax[row] + (ib*NB + il*NF); + + float sumq = 0.f; + FOR_UNROLL (short i = 0; i < NF; ++i) { + sumq += xb[i] * yl[i]; } - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + sumf[row] += sumq; + } - device const T1 * y = (device const T1 *) (src1 + offset1); - device const T14 * y4 = (device const T14 *) y; + yb += NSG*NF*NW; + } - float sumf = 0; - for (int i = tiisg; i < args.ne00/4; i += 32) { - sumf += dot((float4) x4[i], (float4) y4[i]); - } + for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) { + for (short row = 0; row < NR0; row++) { + sumf[row] += ax[row][i] * y[i]; + } + } - float sum_all = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); - dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); +} + +template +kernel void kernel_mul_mv_t_t( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +typedef decltype(kernel_mul_mv_t_t) mul_mv_t_t; + +template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +#endif + +template +void kernel_mul_mv_t_t_4_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + constexpr short NW = N_SIMDWIDTH; + constexpr short NB = 32; + constexpr short NF = 16; + constexpr short NF4 = NF/4; + + const int nb = args.ne00/NB; + + const int r0 = tgpig.x*NR0; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const T1 * y = (device const T1 *) (src1 + offset1); + device const T14 * y4 = (device const T14 *) (src1 + offset1); + + // pointers to src0 rows + device const T0 * ax [NR0]; + device const T04 * ax4[NR0]; + FOR_UNROLL (short row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + ax [row] = (device const T0 *) ((device char *) src0 + offset0); + ax4[row] = (device const T04 *) ((device char *) src0 + offset0); + } + + float sumf[NR0] = { 0.f }; + + const short ix = tiisg/(NW/NF); + const short il = tiisg%(NW/NF); + + const int ib0 = sgitg*NF + ix; + + T14 yl4[NF4]; + + device const T14 * yb4 = y4 + (ib0*NB + il*NF)/4; + + for (int ib = ib0; ib < nb; ib += NSG*NF) { + for (short i = 0; i < NF4; ++i) { + yl4[i] = yb4[i]; + } + + for (short row = 0; row < NR0; row++) { + device const T04 * xb4 = ax4[row] + (ib*NB + il*NF)/4; + + float sumq = 0.f; + FOR_UNROLL (short i = 0; i < NF4; ++i) { + sumq += dot(float4(xb4[i]), float4(yl4[i])); } + + sumf[row] += sumq; + } + + yb4 += NSG*NF*NW/4; + } + + for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) { + for (short row = 0; row < NR0; row++) { + sumf[row] += ax[row][i] * y[i]; } } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } -template -kernel void kernel_mul_mv( +template +kernel void kernel_mul_mv_t_t_4( constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_impl( - args, - src0, - src1, - dst, - tgpig, - tiisg); + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(kernel_mul_mv) mul_mv_t; +typedef decltype(kernel_mul_mv_t_t_4) mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv; -template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv; -template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv; -template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; #endif +#define N_MV_T_T 4 + template void kernel_mul_mv_c4_impl( args_t args, @@ -3399,112 +3709,10 @@ typedef decltype(kernel_mul_mv_c4) mul_mv_c4_t; template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -#endif - -template -kernel void kernel_mul_mv_1row( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]]) { - - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; - - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - - device const T * x = (device const T *) (src0 + offset0); - device const float * y = (device const float *) (src1 + offset1); - - device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - - float sumf = 0; - if (args.ne00 < 128) { - for (int i = tiisg; i < args.ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - float sum_all = simd_sum(sumf); - if (tiisg == 0) { - dst_f32[r0] = sum_all; - } - } else { - device const T4 * x4 = (device const T4 *) x; - device const float4 * y4 = (device const float4 *) y; - - for (int i = tiisg; i < args.ne00/4; i += 32) { - sumf += dot((float4) x4[i], y4[i]); - } - - float sum_all = simd_sum(sumf); - - if (tiisg == 0) { - for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); - dst_f32[r0] = sum_all; - } - } -} - -typedef decltype(kernel_mul_mv_1row) mul_mv_1row_t; - -template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; -#endif - -// Assumes row size (ne00) is a multiple of 4 -template -kernel void kernel_mul_mv_l4( - constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]]) { - - const int nrows = args.ne11; - const int r0 = tgpig.x; - const int im = tgpig.z; - - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; - - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - - device const T4 * x4 = (device const T4 *) (src0 + offset0); - - device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; - - for (int r1 = 0; r1 < nrows; ++r1) { - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - - device const float4 * y4 = (device const float4 *) (src1 + offset1); - - float sumf = 0; - for (int i = tiisg; i < args.ne00/4; i += 32) { - sumf += dot((float4) x4[i], y4[i]); - } - - float sum_all = simd_sum(sumf); - if (tiisg == 0) { - dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; - } - } -} - -typedef decltype(kernel_mul_mv_l4) mul_mv_l4_t; - -template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; +template [[host_name("kernel_mul_mv_f16_f16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; +template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; #endif static float rope_yarn_ramp(const float low, const float high, const int i0) { @@ -3807,9 +4015,9 @@ template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t ker template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision; typedef void (im2col_t)( + constant ggml_metal_kargs_im2col & args, device const float * x, device char * dst, - constant ggml_metal_kargs_im2col & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -3817,9 +4025,9 @@ typedef void (im2col_t)( template kernel void kernel_im2col( + constant ggml_metal_kargs_im2col & args, device const float * x, device char * dst, - constant ggml_metal_kargs_im2col & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -3828,11 +4036,10 @@ kernel void kernel_im2col( const int64_t OH = tgpg[1]; const int64_t OW = tgpg[2]; -// const int64_t N = ntg[0]; const int64_t KH = ntg[1]; const int64_t KW = ntg[2]; - const int64_t in = tpitg[0]; + int64_t in = tpitg[0]; const int64_t ikh = tpitg[1]; const int64_t ikw = tpitg[2]; @@ -3843,88 +4050,96 @@ kernel void kernel_im2col( const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0; const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1; - const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); + int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); device T * pdst = (device T *) (dst); if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { pdst[offset_dst] = 0.0f; } else { - const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; - pdst[offset_dst] = x[offset_src]; - } -} + int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; -template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; -template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; + while (in < args.N) { + pdst[offset_dst] = x[offset_src]; -typedef void (im2col_ext_t)( - device const float * x, - device char * dst, - constant ggml_metal_kargs_im2col & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]); + offset_dst += ntg[0]*args.CHW*OH*OW; + offset_src += ntg[0]*args.ofs0; -template -kernel void kernel_im2col_ext( - device const float * x, - device char * dst, - constant ggml_metal_kargs_im2col & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] - const int64_t KHW = (int64_t)args.KHW; - - const int64_t d = tgpig[0] / args.CHW; - const int64_t chw = tgpig[0] % args.CHW; - const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) - const int64_t HW = tgpig[0] % KHW; - - const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; - if (tpitg_0 >= args.N) { - return; - } - - const int64_t tpitg_1 = HW / args.KW; - const int64_t tpitg_2 = HW % args.KW; - - const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; - const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; - - const int64_t offset_dst = - (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + - (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); - - device T * pdst = (device T *) (dst); - - if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { - pdst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; - pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; + in += ntg[0]; + } } } -template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; -template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; +template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; +template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; + +//typedef void (im2col_ext_t)( +// constant ggml_metal_kargs_im2col & args, +// device const float * x, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// uint3 tgpg[[threadgroups_per_grid]], +// uint3 tpitg[[thread_position_in_threadgroup]], +// uint3 ntg[[threads_per_threadgroup]]); +// +//template +//kernel void kernel_im2col_ext( +// constant ggml_metal_kargs_im2col & args, +// device const float * x, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW +// uint3 tpitg[[thread_position_in_threadgroup]], +// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] +// const int64_t KHW = (int64_t)args.KHW; +// +// const int64_t d = tgpig[0] / args.CHW; +// const int64_t chw = tgpig[0] % args.CHW; +// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) +// const int64_t HW = tgpig[0] % KHW; +// +// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; +// if (tpitg_0 >= args.N) { +// return; +// } +// +// const int64_t tpitg_1 = HW / args.KW; +// const int64_t tpitg_2 = HW % args.KW; +// +// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; +// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; +// +// const int64_t offset_dst = +// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + +// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); +// +// device T * pdst = (device T *) (dst); +// +// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { +// pdst[offset_dst] = 0.0f; +// } else { +// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; +// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; +// } +//} +// +//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; +//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; typedef void (conv_transpose_1d_t)( + constant ggml_metal_kargs_conv_transpose_1d & args, device const float * src0, device const float * src1, device char * dst, - constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); template kernel void kernel_conv_transpose_1d( + constant ggml_metal_kargs_conv_transpose_1d & args, device const T * src0, device const float * src1, device char * dst, - constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]) { @@ -3948,26 +4163,26 @@ kernel void kernel_conv_transpose_1d( template [[host_name("kernel_conv_transpose_1d_f32_f32")]] kernel void kernel_conv_transpose_1d( + constant ggml_metal_kargs_conv_transpose_1d & args, device const float * src0, device const float * src1, device char * dst, - constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); template [[host_name("kernel_conv_transpose_1d_f16_f32")]] kernel void kernel_conv_transpose_1d( + constant ggml_metal_kargs_conv_transpose_1d & args, device const half * src0, device const float * src1, device char * dst, - constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); kernel void kernel_upscale_f32( + constant ggml_metal_kargs_upscale & args, device const char * src0, device char * dst, - constant ggml_metal_kargs_upscale & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -3991,9 +4206,9 @@ kernel void kernel_upscale_f32( } kernel void kernel_pad_f32( + constant ggml_metal_kargs_pad & args, device const char * src0, device char * dst, - constant ggml_metal_kargs_pad & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -4027,9 +4242,9 @@ kernel void kernel_pad_f32( } kernel void kernel_pad_reflect_1d_f32( + constant ggml_metal_kargs_pad_reflect_1d & args, device const char * src0, device char * dst, - constant ggml_metal_kargs_pad_reflect_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -4060,8 +4275,8 @@ kernel void kernel_pad_reflect_1d_f32( } kernel void kernel_arange_f32( - device char * dst, constant ggml_metal_kargs_arange & args, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -4074,9 +4289,9 @@ kernel void kernel_arange_f32( } kernel void kernel_timestep_embedding_f32( + constant ggml_metal_kargs_timestep_embedding & args, device const char * src0, device char * dst, - constant ggml_metal_kargs_timestep_embedding & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -4094,25 +4309,25 @@ kernel void kernel_timestep_embedding_f32( } if (args.dim % 2 != 0 && tpitg.x == 0) { - embed_data[args.dim] = 0.f; + embed_data[2 * half_] = 0.f; } } // bitonic sort implementation following the CUDA kernels as reference typedef void (argsort_t)( - device const float * x, - device int32_t * dst, constant ggml_metal_kargs_argsort & args, + device const float * x, + device int32_t * dst, threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]); template kernel void kernel_argsort_f32_i32( - device const float * x, - device int32_t * dst, constant ggml_metal_kargs_argsort & args, - threadgroup int32_t * shared_values [[threadgroup(0)]], + device const float * x, + device int32_t * dst, + threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]) { // bitonic sort @@ -4165,13 +4380,36 @@ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_ar template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; kernel void kernel_leaky_relu_f32( + constant ggml_metal_kargs_leaky_relu & args, device const float * src0, device float * dst, + uint tpig[[thread_position_in_grid]]) { + const float x = src0[tpig]; + dst[tpig] = x > 0.0f ? x : x * args.slope; +} + +kernel void kernel_leaky_relu_f32_4( constant ggml_metal_kargs_leaky_relu & args, + device const float4 * src0, + device float4 * dst, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope; + const float4 x = src0[tpig]; + dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope); } +constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]]; +constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]]; +constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]]; +constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]]; + +//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; +//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]]; +//constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]]; + +constant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]]; +constant int32_t FC_flash_attn_ext_ns20 [[function_constant(FC_FLASH_ATTN_EXT + 21)]]; +constant int32_t FC_flash_attn_ext_nsg [[function_constant(FC_FLASH_ATTN_EXT + 22)]]; + // ref: https://arxiv.org/pdf/2307.08691.pdf template< typename q_t, // query types in shared memory @@ -4186,6 +4424,7 @@ template< typename qk_t, // Q*K types typename qk8x8_t, typename s_t, // soft-max types + typename s2_t, typename s8x8_t, typename o_t, // attention accumulation types typename o4_t, @@ -4196,12 +4435,12 @@ template< typename vd4x4_t, // value type in device memory short nl_v, void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), - short DK, // K head size - short DV, // V head size - short Q = 8, // queries per threadgroup - short KV = 8, // key/value processed per each simdgroup - short C = 32> // cache items per threadgroup -kernel void kernel_flash_attn_ext( + short DK, // K head size + short DV, // V head size + short Q, // queries per threadgroup + short C, // cache items per threadgroup + short NSG> // number of simd groups +void kernel_flash_attn_ext_impl( constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, device const char * k, @@ -4209,46 +4448,85 @@ kernel void kernel_flash_attn_ext( device const char * mask, device const char * sinks, device char * dst, - threadgroup half * shmem_f16 [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups - - const int iq3 = tgpig[2]; - const int iq2 = tgpig[1]; - const int iq1 = tgpig[0]*Q; + threadgroup half * shmem_f16, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const ushort iq3 = tgpig[2]; + const ushort iq2 = tgpig[1]; + const ushort iq1 = tgpig[0]*Q; + +#define NS10 (FC_flash_attn_ext_ns10) +#define NS20 (FC_flash_attn_ext_ns20) + + // note: I had some concerns that using this instead of the ugly macros above was affecting performance + // need to re-check carefully and if no regressions are observerd - remove the macros + // the concerns is that maybe using const variables requires extra registers? but not sure if the compiler + // is clever enough to avoid this. unfortunately, using constexpr is not possible with FC + //const short NS10 = FC_flash_attn_ext_ns10; + //const short NS20 = FC_flash_attn_ext_ns20; + + constexpr short KV = 8; constexpr short DK4 = DK/4; constexpr short DK8 = DK/8; constexpr short DK16 = DK/16; constexpr short DV4 = DV/4; - constexpr short DV8 = DV/8; + //constexpr short DV8 = DV/8; constexpr short DV16 = DV/16; + constexpr short PV = PAD2(DV, 64); + constexpr short PV4 = PV/4; + constexpr short PV8 = PV/8; + //constexpr short PV16 = PV/16; + constexpr short NW = N_SIMDWIDTH; - constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) + constexpr short NQ = Q/NSG; + constexpr short SH = 2*C; // shared memory per simdgroup (s_t == float) + + constexpr short TS = 2*SH; + constexpr short T = DK + 2*PV; // shared memory size per query in (half) - const short TS = nsg*SH; // shared memory size per query in (s_t == float) - const short T = 2*DK + 2*TS; // shared memory size per query in (half) + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*T); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*T); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper) + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*T + Q*DK); + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + Q*T); // scratch buffer for attention, mask and diagonal matrix + threadgroup s2_t * ss2 = (threadgroup s2_t *) (shmem_f16 + Q*T); // same as above but in s2_t - threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in k4x4_t - threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory - threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t - threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory - threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t + // mask storage in shared mem + threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C); - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - o8x8_t lo[DV8]; + // per-query mask pointers + device const half2 * pm2[NQ]; + + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); + } + + { + q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += ikv2*args.nb12 + ikv3*args.nb13; + v += ikv2*args.nb22 + ikv3*args.nb23; + } // load heads from Q to shared memory - for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01); for (short i = tiisg; i < DK4; i += NW) { if (iq1 + j < args.ne01) { @@ -4259,43 +4537,30 @@ kernel void kernel_flash_attn_ext( } } - // zero out lo - for (short i = 0; i < DV8; ++i) { - lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f); - } + // zero out + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] = 0; + } - // zero out shared memory SH - for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < SH; i += NW) { - ss[j*TS + i] = 0.0f; + ss[j*SH + i] = 0.0f; } } threadgroup_barrier(mem_flags::mem_threadgroup); - { - float S[Q] = { [0 ... Q-1] = 0.0f }; - float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 }; - - // thread indices inside the simdgroup - // TODO: see if we can utilize quad-group functions for better performance - // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3) - const short tx = tiisg%4; - const short ty = tiisg/4; - - // broadcast kv - //const short rk2 = args.ne02/args.ne12; - //const short rk3 = args.ne03/args.ne13; - - const short ikv2 = iq2/(args.ne02/args.ne_12_2); - const short ikv3 = iq3/(args.ne03/args.ne_12_3); + float S[NQ] = { [0 ... NQ-1] = 0.0f }; - const bool has_mask = mask != q; + { + float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 }; float slope = 1.0f; // ALiBi - if (args.max_bias > 0.0f) { + if (FC_flash_attn_ext_has_bias) { const short h = iq2; const float base = h < args.n_head_log2 ? args.m0 : args.m1; @@ -4306,177 +4571,277 @@ kernel void kernel_flash_attn_ext( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { - const int ic = ic0 + C*sgitg; - if (ic >= args.ne11) { - break; - } - - if (has_mask) { - // used to detect blocks full of -INF - float smax = -INFINITY; + for (int ic = 0; ic < args.ne11; ic += C) { + // read the mask into shared mem + if (FC_flash_attn_ext_has_mask) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + pm2[jj] += NW; + } - // load the mask in shared memory - #pragma unroll(Q) - for (short j = 0; j < Q; ++j) { - device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); + threadgroup_barrier(mem_flags::mem_threadgroup); - const float m = pm[ic + tiisg]; + // used to detect blocks full of -INF + // skip only when the entire threadgroup is masked + half2 smax2(-MAXHALF/2, -MAXHALF/2); - ss[j*TS + C + tiisg] = m; - smax = max(smax, m); + FOR_UNROLL (short j = 0; j < Q; ++j) { + smax2 = max(smax2, sm2[j*SH + tiisg]); } - smax = simd_max(smax); + smax2 = simd_max(smax2); + + if (max(smax2[0], smax2[1]) <= -MAXHALF/2) { + // this barrier is important + threadgroup_barrier(mem_flags::mem_threadgroup); - if (smax == -INFINITY) { continue; } } // Q*K^T - { - for (short cc = 0; cc < C/8; ++cc) { + // this is compile-time check, so it does not have runtime overhead + if (is_same::value) { + // we can read directly from global memory + device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11); + threadgroup const q_t * pq = sq; + threadgroup s_t * ps = ss; + + pk += sgitg*(8*NS10); + ps += sgitg*(8*1); + + static_assert((C/8) % NSG == 0, ""); + + constexpr short NC = (C/8)/NSG; + + // TODO: not good to unroll for large contexts - not sure why? + for (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - // this is compile-time check, so it does not have runtime overhead - if (is_same::value) { - // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + if (DK8 % 16 != 0) { + k8x8_t mk; + q8x8_t mq; + + FOR_UNROLL (short i = 0; i < DK8; ++i) { + simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(DK8) - for (short i = 0; i < DK8; ++i) { - k8x8_t mk; - simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + simdgroup_load(mk, pk, NS10, 0, true); + simdgroup_load(mq, pq, DK); + + simdgroup_barrier(mem_flags::mem_none); - q8x8_t mq; - simdgroup_load(mq, sq + i*8, DK); simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + pk += 8; + pq += 8; } } else { - for (short ii = 0; ii < DK16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + k8x8_t mk[2]; + q8x8_t mq[2]; - if (DK16%4 == 0) { - // the head is evenly divisible by 4*16 = 64, so no need for bound checks - { - k4x4_t tmp; - deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); - sk4x4[4*ty + tx] = tmp; - } + FOR_UNROLL (short i = 0; i < DK8/2; ++i) { + simdgroup_barrier(mem_flags::mem_none); - simdgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_load(mk[0], pk + 0*8, NS10, 0, true); + simdgroup_load(mk[1], pk + 1*8, NS10, 0, true); - #pragma unroll(4) - for (short k = 0; k < 4; ++k) { - k8x8_t mk; - q8x8_t mq; + simdgroup_load(mq[0], pq + 0*8, DK); + simdgroup_load(mq[1], pq + 1*8, DK); - simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - } - } else { - if (ii + tx < DK16) { - k4x4_t tmp; - deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); - sk4x4[4*ty + tx] = tmp; - } + simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); + simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); - simdgroup_barrier(mem_flags::mem_threadgroup); + pk += 16; + pq += 16; + } + } - for (short k = 0; k < 4 && ii + k < DK16; ++k) { - k8x8_t mk; - q8x8_t mq; + simdgroup_store(mqk, ps, SH, 0, false); - simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + pk += 8*(NSG*NS10 - DK8); + pq += 8*(NSG*0 - DK8); + ps += 8*(NSG); + } + } else { + // TODO: this is the quantized K cache branch - not optimized yet + for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) { + const short cc = ccc*NSG + sgitg; - simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - } + const short tx = tiisg%4; + const short ty = tiisg/4; + + qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); + + for (short ii = 0; ii < DK16; ii += 4) { + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11)); + + if (DK16%4 == 0) { + // the head is evenly divisible by 4*16 = 64, so no need for bound checks + { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short k = 0; k < 4; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + if (ii + tx < DK16) { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < DK16; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); } } } - // cast qk_t -> s_t - //s8x8_t mqks(1.0f); - //simdgroup_multiply(mqks, mqk, mqks); - //simdgroup_store(mqks, ss + 8*cc, TS, 0, false); - - simdgroup_store(mqk, ss + 8*cc, TS, 0, false); + simdgroup_store(mqk, ss + 8*cc, SH, 0, false); } } + threadgroup_barrier(mem_flags::mem_threadgroup); + // online softmax - { - for (ushort j = 0; j < Q; ++j) { - const float m = M[j]; + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - // scale and apply the logitcap / mask - float s = ss[j*TS + tiisg]*args.scale; + const float m = M[jj]; - if (args.logit_softcap != 0.0f) { - s = args.logit_softcap*precise::tanh(s); - } + // scale and apply the logitcap / mask + float2 s2 = ss2[j*SH/2 + tiisg]*args.scale; - // mqk = mqk + mask*slope - s += slope*ss[j*TS + C + tiisg]; + if (FC_flash_attn_ext_has_scap) { + s2 = args.logit_softcap*precise::tanh(s2); + } - M[j] = simd_max(max(M[j], s)); + // mqk = mqk + slope*mask + if (FC_flash_attn_ext_has_bias) { + s2 += s2_t(sm2[j*SH + tiisg])*slope; + } else { + s2 += s2_t(sm2[j*SH + tiisg]); + } - const float ms = exp(m - M[j]); - const float vs = exp(s - M[j]); + M[jj] = simd_max(max(M[jj], max(s2[0], s2[1]))); - S[j] = S[j]*ms + simd_sum(vs); + const float ms = exp(m - M[jj]); + const float2 vs2 = exp(s2 - M[jj]); - // the P matrix from the paper (Q rows, C columns) - ss[j*TS + tiisg] = vs; + S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]); - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*TS + 2*C + j] = ms; - } - } - } + // the P matrix from the paper (Q rows, C columns) + ss2[j*SH/2 + tiisg] = vs2; - // O = diag(ms)*O - { - s8x8_t ms; - simdgroup_load(ms, ss + 2*C, TS, 0, false); + if (DV4 % NW == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { + const short i = ii*NW + tiisg; - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_multiply(lo[i], ms, lo[i]); + so4[j*PV4 + i] *= ms; + } + } else { + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] *= ms; + } } } + threadgroup_barrier(mem_flags::mem_threadgroup); + // O = O + (Q*K^T)*V { - for (short cc = 0; cc < C/8; ++cc) { - s8x8_t vs; - simdgroup_load(vs, ss + 8*cc, TS, 0, false); + // we can read directly from global memory + if (is_same::value) { + static_assert(PV8 % NSG == 0, ""); + + constexpr short NO = PV8/NSG; - if (is_same::value) { - // we can read directly from global memory - device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + o8x8_t lo[NO]; - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - v8x8_t mv; - simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 + { + auto sot = so + 8*sgitg; - simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]); + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + simdgroup_load(lo[ii], sot, PV, 0, false); + + sot += 8*NSG; } - } else { - for (short ii = 0; ii < DV16; ii += 4) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + } + + { + auto sst = ss; + + device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21); + + pv += 8*sgitg; + + FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, sst, SH, 0, false); + + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + v8x8_t mv; + + simdgroup_load(mv, pv, NS20, 0, false); + simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]); + + pv += 8*NSG; + } + + pv += 8*(NS20 - NO*NSG); + sst += 8; + } + } + + { + auto sot = so + 8*sgitg; + + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + simdgroup_store(lo[ii], sot, PV, 0, false); + + sot += 8*NSG; + } + } + } else { + // TODO: this is the quantized V cache branch - not optimized yet + + const short tx = tiisg%4; + const short ty = tiisg/4; + + for (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, SH, 0, false); + + for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) { + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21)); if (DV16%4 == 0) { // no need for bound checks @@ -4488,15 +4853,20 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); - #pragma unroll(4) - for (short k = 0; k < 4; ++k) { - v8x8_t mv; + FOR_UNROLL (short k = 0; k < 4; ++k) { + v8x8_t mv[2]; + o8x8_t lo[2]; + + simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); - simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); + simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]); + simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]); - simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); + simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); } } else { if (ii + tx < DV16) { @@ -4508,236 +4878,249 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); for (short k = 0; k < 4 && ii + k < DV16; ++k) { - v8x8_t mv; + v8x8_t mv[2]; + o8x8_t lo[2]; - simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); + simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); - simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); + simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]); + simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]); + + simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); } } } } } } - } - if (sinks != q && sgitg == 0) { - for (ushort j = 0; j < Q; ++j) { - const float m = M[j]; - const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; + threadgroup_barrier(mem_flags::mem_threadgroup); + } - M[j] = simd_max(max(M[j], s)); + if (FC_flash_attn_ext_has_sinks) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - const float ms = exp(m - M[j]); - const float vs = exp(s - M[j]); + const float m = M[jj]; + const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; - S[j] = S[j]*ms + simd_sum(vs); + M[jj] = simd_max(max(M[jj], s)); - if (tiisg == j) { - ss[j*TS + 2*C + j] = ms; - } - } + const float ms = exp(m - M[jj]); + const float vs = exp(s - M[jj]); - // O = diag(ms)*O - { - s8x8_t ms; - simdgroup_load(ms, ss + 2*C, TS, 0, false); + S[jj] = S[jj]*ms + simd_sum(vs); - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_multiply(lo[i], ms, lo[i]); + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] *= ms; } } } - - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (short j = tiisg; j < Q; j += NW) { - ss[j*TS + 0] = S[j]; - ss[j*TS + 1] = M[j]; - } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation - threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK); - - // store result to shared memory in F32 - if (sgitg == 0) { - for (short i = 0; i < DV8; ++i) { - //simdgroup_store(lo[i], so + i*8, DV, 0, false); - simdgroup_float8x8 t(1.0f); - simdgroup_multiply(t, lo[i], t); - simdgroup_store(t, so + i*8, DV, 0, false); + // store to global memory + for (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + if (iq1 + j >= args.ne01) { + break; } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // reduce the warps sequentially - for (ushort sg = 1; sg < nsg; ++sg) { - if (sgitg == sg) { - for (short j = tiisg; j < Q; j += NW) { - const float S0 = ss[j*TS - 1*SH + 0]; - const float S1 = ss[j*TS + 0]; - - const float M0 = ss[j*TS - 1*SH + 1]; - const float M1 = ss[j*TS + 1]; - const float M = max(M0, M1); - - float ms0 = exp(M0 - M); - float ms1 = exp(M1 - M); + device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; - const float S = S0*ms0 + S1*ms1; + const float scale = 1.0f/S[jj]; - ss[j*TS + 0] = S; - ss[j*TS + 1] = M; + if (DV4 % NW == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { + const short i = ii*NW + tiisg; - ss[j*TS + 2*C + j - 1*SH] = ms0; - ss[j*TS + 2*C + j ] = ms1; + dst4[i] = (float4) so4[j*PV4 + i]*scale; } - - //simdgroup_barrier(mem_flags::mem_threadgroup); - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - { - s8x8_t ms0; - s8x8_t ms1; - - simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false); - simdgroup_load(ms1, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_float8x8 t; - - simdgroup_load (t, so + i*8, DV, 0, false); - simdgroup_multiply(t, ms0, t); - - simdgroup_multiply_accumulate(t, ms1, lo[i], t); - simdgroup_store(t, so + i*8, DV, 0, false); - } + } else { + for (short i = tiisg; i < DV4; i += NW) { + dst4[i] = (float4) so4[j*PV4 + i]*scale; } } - - threadgroup_barrier(mem_flags::mem_threadgroup); } - threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK); - - // final rescale with 1/S and store to global memory - for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) { - const float S = 1.0f/sf[j*TS + 0]; - - device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; +#undef NS10 +#undef NS20 +} - for (short i = tiisg; i < DV4; i += NW) { - dst4[i] = (float4) so4[j*DV4 + i]*S; - } +template< + typename q_t, // query types in shared memory + typename q4_t, + typename q8x8_t, + typename k_t, // key types in shared memory + typename k4x4_t, + typename k8x8_t, + typename v_t, // value types in shared memory + typename v4x4_t, + typename v8x8_t, + typename qk_t, // Q*K types + typename qk8x8_t, + typename s_t, // soft-max types + typename s2_t, + typename s8x8_t, + typename o_t, // attention accumulation types + typename o4_t, + typename o8x8_t, + typename kd4x4_t, // key type in device memory + short nl_k, + void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), + typename vd4x4_t, // value type in device memory + short nl_v, + void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), + short DK, // K head size + short DV, // V head size + short Q = 8, // queries per threadgroup + short C = 64> // cache items per threadgroup +kernel void kernel_flash_attn_ext( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device const char * sinks, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C +#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg + switch (FC_flash_attn_ext_nsg) { + // note: disabled cases to reduce library load time + //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; + //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; + case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; } +#undef FWD_TMPL +#undef FWD_ARGS } // TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as // template to be able to explore different combinations // #define FA_TYPES \ - float, float4, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ - half, half4, simdgroup_half8x8 - //float, float4, simdgroup_float8x8 + float, float2, simdgroup_float8x8, \ + float, float4, simdgroup_float8x8 + //half, half4, simdgroup_half8x8 #define FA_TYPES_BF \ bfloat, bfloat4, simdgroup_bfloat8x8, \ bfloat, bfloat4x4, simdgroup_bfloat8x8, \ bfloat, bfloat4x4, simdgroup_bfloat8x8, \ float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ + float, float2, simdgroup_float8x8, \ half, half4, simdgroup_half8x8 //float, float4, simdgroup_float8x8 typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; - -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif -template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; - -template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; - -template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; - -template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; - -template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES #undef FA_TYPES_BF +constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]]; +constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; +constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]]; +constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]]; + +//constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]]; +//constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]]; +//constant float FC_flash_attn_ext_vec_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 12)]]; + +constant int32_t FC_flash_attn_ext_vec_ns10 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 20)]]; +constant int32_t FC_flash_attn_ext_vec_ns20 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 21)]]; +constant int32_t FC_flash_attn_ext_vec_nsg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 22)]]; +constant int32_t FC_flash_attn_ext_vec_nwg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 23)]]; + template< typename q4_t, // query types in shared memory typename k4_t, // key types in shared memory @@ -4756,9 +5139,10 @@ template< short DV, // V head size short NE = 4, // head elements per thread short Q = 1, // queries per threadgroup - short C = 32> // cache items per threadgroup -kernel void kernel_flash_attn_ext_vec( - constant ggml_metal_kargs_flash_attn_ext & args, + short C = 32, // cache items per threadgroup + short NSG> // number of simd groups +void kernel_flash_attn_ext_vec_impl( + constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, device const char * k, device const char * v, @@ -4767,47 +5151,74 @@ kernel void kernel_flash_attn_ext_vec( device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups + static_assert(DK % 32 == 0, "DK must be divisible by 32"); + static_assert(DV % 32 == 0, "DV must be divisible by 32"); - const int iq3 = tgpig[2]; - const int iq2 = tgpig[1]; - const int iq1 = tgpig[0]; +#define NWG (FC_flash_attn_ext_vec_nwg) + +#define NS10 (FC_flash_attn_ext_vec_ns10) +#define NS20 (FC_flash_attn_ext_vec_ns20) + + const short iwg = tgpig[2]%NWG; + + const ushort iq3 = tgpig[2]/NWG; + const ushort iq2 = tgpig[1]; + const ushort iq1 = tgpig[0]; constexpr short DK4 = DK/4; constexpr short DV4 = DV/4; + + constexpr short PK = PAD2(DK, 128); + constexpr short PK4 = PK/4; + + constexpr short PV = PAD2(DV, 128); + constexpr short PV4 = PV/4; + constexpr short NW = N_SIMDWIDTH; constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads constexpr short SH = 4*C; // shared memory per simdgroup - const short T = DK + nsg*SH; // shared memory size per query in (half) + static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); + static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); + + const short T = PK + NSG*SH; // shared memory size per query in (half) + + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t - threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask - threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results + // store the result for all queries in shared memory (the O matrix from the paper) + so4 += tiisg; - // store the result for all queries in local memory (the O matrix from the paper) - o4_t lo[DV4/NL]; + { + q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += ikv2*args.nb12 + ikv3*args.nb13; + v += ikv2*args.nb22 + ikv3*args.nb23; + } // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q); - for (short i = tiisg; i < DK4; i += NW) { - if (iq1 < args.ne01) { + for (short i = tiisg; i < PK4; i += NW) { + if (iq1 < args.ne01 && i < DK4) { sq4[i] = (q4_t) q4[i]; } else { sq4[i] = (q4_t) 0.0f; } } - // zero out lo + // zero out so for (short i = 0; i < DV4/NL; ++i) { - lo[i] = (o4_t) 0.0f; + so4[i*NL] = (o4_t) 0.0f; } // zero out shared memory SH @@ -4819,28 +5230,19 @@ kernel void kernel_flash_attn_ext_vec( { float S = 0.0f; - float M = -__FLT_MAX__/2; + float M = -FLT_MAX/2; // thread indices inside the simdgroup const short tx = tiisg%NL; const short ty = tiisg/NL; - // broadcast kv - //const short rk2 = args.ne02/args.ne12; - //const short rk3 = args.ne03/args.ne13; - - const short ikv2 = iq2/(args.ne02/args.ne_12_2); - const short ikv3 = iq3/(args.ne03/args.ne_12_3); - - const bool has_mask = mask != q; - // pointer to the mask device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); float slope = 1.0f; // ALiBi - if (args.max_bias > 0.0f) { + if (FC_flash_attn_ext_vec_has_bias) { const short h = iq2; const float base = h < args.n_head_log2 ? args.m0 : args.m1; @@ -4851,13 +5253,13 @@ kernel void kernel_flash_attn_ext_vec( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { + for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) { const int ic = ic0 + C*sgitg; if (ic >= args.ne11) { break; } - if (has_mask) { + if (FC_flash_attn_ext_vec_has_mask) { sm[tiisg] = pm[ic + tiisg]; } @@ -4868,69 +5270,81 @@ kernel void kernel_flash_attn_ext_vec( // Q*K^T { - // each simdgroup processes 1 query and NE (NW/NL) head elements - for (short cc = 0; cc < C/NE; ++cc) { - qk_t mqk = 0.0f; + device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11); + threadgroup const q4_t * pq4 = sq4; + + pk4 += ty*NS10/4 + tx; + pq4 += tx; - device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + qk_t mqk[C/NE] = { [ 0 ... C/NE - 1] = 0.0f }; - #pragma unroll(DK4/NL) - for (short ii = 0; ii < DK4; ii += NL) { - const short i = ii + tx; + // each simdgroup processes 1 query and NE (NW/NL) cache elements + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + if (is_same::value) { + FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { + mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); + } + } else { + device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11)); k4_t mk; - deq_k_t4(pk + i/nl_k, i%nl_k, mk); - // note: this is less precise than the version below - //mqka[0] += dot(mq[0], mk[0]); - //mqka[1] += dot(mq[1], mk[1]); - //mqka[2] += dot(mq[2], mk[2]); - //mqka[3] += dot(mq[3], mk[3]); + FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { + const short i = ii*NL + tx; - //q4x4_t mq = sq4x4[i]; - //mqka[0] += dot((float4) mq[0], (float4) mk[0]); - //mqka[1] += dot((float4) mq[1], (float4) mk[1]); - //mqka[2] += dot((float4) mq[2], (float4) mk[2]); - //mqka[3] += dot((float4) mq[3], (float4) mk[3]); + deq_k_t4(pk + i/nl_k, i%nl_k, mk); - mqk += dot((float4) mk, (float4) sq4[i]); + mqk[cc] += dot((float4) mk, (float4) sq4[i]); + } } - static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails + if (NE == 1) { + mqk[cc] = simd_sum(mqk[cc]); + } else { + // simdgroup reduce (NE = 4) + // [ 0 .. 7] -> [ 0] + // [ 8 .. 15] -> [ 8] + // [16 .. 23] -> [16] + // [24 .. 31] -> [24] + if (NE <= 1) { + mqk[cc] += simd_shuffle_down(mqk[cc], 16); + } + if (NE <= 2) { + mqk[cc] += simd_shuffle_down(mqk[cc], 8); + } + if (NE <= 4) { + mqk[cc] += simd_shuffle_down(mqk[cc], 4); + } + if (NE <= 8) { + mqk[cc] += simd_shuffle_down(mqk[cc], 2); + } + if (NE <= 16) { + mqk[cc] += simd_shuffle_down(mqk[cc], 1); + } - // simdgroup reduce (NE = 4) - // [ 0 .. 7] -> [ 0] - // [ 8 .. 15] -> [ 8] - // [16 .. 23] -> [16] - // [24 .. 31] -> [24] - if (NE <= 1) { - mqk += simd_shuffle_down(mqk, 16); - } - if (NE <= 2) { - mqk += simd_shuffle_down(mqk, 8); - } - if (NE <= 4) { - mqk += simd_shuffle_down(mqk, 4); - } - if (NE <= 8) { - mqk += simd_shuffle_down(mqk, 2); + // broadcast + mqk[cc] = simd_shuffle(mqk[cc], NL*ty); } - if (NE <= 16) { - mqk += simd_shuffle_down(mqk, 1); - } - - // mqk = mqk*scale + mask*slope - if (tx == 0) { - mqk *= args.scale; + } - if (args.logit_softcap != 0.0f) { - mqk = args.logit_softcap*precise::tanh(mqk); - } + if (FC_flash_attn_ext_vec_has_mask && + !FC_flash_attn_ext_vec_has_scap && + !FC_flash_attn_ext_vec_has_bias) { + ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]); + } else { + mqk[tx] *= args.scale; - mqk += sm[NE*cc + ty]*slope; + if (FC_flash_attn_ext_vec_has_scap) { + mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]); + } - ss[NE*cc + ty] = mqk; + if (FC_flash_attn_ext_vec_has_bias) { + mqk[tx] += (qk_t) sm[NE*tx + ty]*slope; + } else { + mqk[tx] += (qk_t) sm[NE*tx + ty]; } + + ss[NE*tx + ty] = mqk[tx]; } } @@ -4952,9 +5366,10 @@ kernel void kernel_flash_attn_ext_vec( ss[tiisg] = vs; // O = diag(ms)*O - #pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - lo[ii/NL] *= ms; + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] *= ms; + } } } @@ -4962,26 +5377,84 @@ kernel void kernel_flash_attn_ext_vec( // O = O + (Q*K^T)*V { - //#pragma unroll(C/NE) - for (short cc = 0; cc < C/NE; ++cc) { - device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + o4_t lo[DV4/NL]; + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + lo[ii] = 0.0f; + } + + if (is_same::value) { + device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21); + + pv4 += ty*NS20/4 + tx; + + const auto sst = ss + ty; + + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE])); + } + } + } else { + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21)); + + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + const short i = ii*NL + tx; + + v4_t mv; + deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + + lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty])); + } + } + } + + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + if (NE > 1) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 16); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 16); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 16); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 16); + } + + if (NE > 2) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 8); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 8); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 8); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 8); + } - const s4_t ms(ss[NE*cc + ty]); + if (NE > 4) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 4); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 4); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 4); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 4); + } - #pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - const short i = ii + tx; + if (NE > 8) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 2); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 2); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 2); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 2); + } - v4_t mv; - deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + if (NE > 16) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 1); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 1); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 1); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 1); + } + } - lo[ii/NL] += o4_t(float4(mv)*float4(ms)); + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] += lo[ii]; } } } } - if (sinks != q && sgitg == 0) { + if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) { const float m = M; const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; @@ -4992,9 +5465,10 @@ kernel void kernel_flash_attn_ext_vec( S = S*ms + simd_sum(vs); -#pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - lo[ii/NL] *= ms; + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] *= ms; + } } } @@ -5005,63 +5479,12 @@ kernel void kernel_flash_attn_ext_vec( } } - // simdgroup reduce (NE = 4) - // [ 0, 8, 16, 24] -> [ 0] - // [ 1, 9, 17, 25] -> [ 1] - // [ 2, 10, 18, 26] -> [ 2] - // [ 3, 11, 19, 27] -> [ 3] - // [ 4, 12, 20, 28] -> [ 4] - // [ 5, 13, 21, 29] -> [ 5] - // [ 6, 14, 22, 30] -> [ 6] - // [ 7, 15, 23, 31] -> [ 7] - for (short ii = 0; ii < DV4; ii += NL) { - if (NE > 1) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); - } - - if (NE > 2) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); - } - - if (NE > 4) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); - } - - if (NE > 8) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); - } - - if (NE > 16) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // store results to shared memory - for (short i = tiisg; i < DV4; i += NL) { - sr4[i] = lo[i/NL]; - } + so4 -= tiisg; threadgroup_barrier(mem_flags::mem_threadgroup); // parallel reduce - for (short r = nsg/2; r > 0; r >>= 1) { + for (short r = NSG/2; r > 0; r >>= 1) { if (sgitg < r) { const float S0 = ss[ 0]; const float S1 = ss[r*(SH/2) + 0]; @@ -5083,23 +5506,86 @@ kernel void kernel_flash_attn_ext_vec( // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 for (short i = tiisg; i < DV4; i += NW) { - sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1; + so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1; } } threadgroup_barrier(mem_flags::mem_threadgroup); } - device float4 * dst4 = (device float4 *) dst; - // final rescale with 1/S and store to global memory if (sgitg == 0) { - const float S = ss[0]; + const int64_t nrows = args.ne3*args.ne2*args.ne1; + const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1; + + device float4 * dst4 = (device float4 *) dst; + device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results + const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f; + + // interleave the workgroup data for (short i = tiisg; i < DV4; i += NW) { - dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S; + dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S; + } + + // store S and M + if (NWG > 1) { + if (tiisg == 0) { + dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0]; + dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1]; + } } } + +#undef NWG +#undef NS10 +#undef NS20 +} + +template< + typename q4_t, // query types in shared memory + typename k4_t, // key types in shared memory + typename v4_t, // value types in shared memory + typename qk_t, // Q*K types + typename s_t, // soft-max types + typename s4_t, + typename o4_t, // attention accumulation types + typename kd4_t, // key type in device memory + short nl_k, + void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), + typename vd4_t, // value type in device memory + short nl_v, + void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), + short DK, // K head size + short DV, // V head size + short NE = 4, // head elements per thread + short Q = 1, // queries per threadgroup + short C = 32> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( + constant ggml_metal_kargs_flash_attn_ext_vec & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device const char * sinks, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C +#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg + switch (FC_flash_attn_ext_vec_nsg) { + // note: disabled cases to reduce library load time + case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + case 2: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + case 4: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 8: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 16: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 32: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + } +#undef FWD_TMPL +#undef FWD_ARGS } // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem @@ -5115,109 +5601,121 @@ kernel void kernel_flash_attn_ext_vec( typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; -template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; - -template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; - -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; - -template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; - -template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; - -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; - -template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #undef FA_TYPES -template -kernel void kernel_set( - constant ggml_metal_kargs_set & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i13 = tgpig[2]; - const int i12 = tgpig[1]; - const int i11 = tgpig[0]; +constant int32_t FC_flash_attn_ext_vec_reduce_DV [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]]; +constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]]; + +kernel void kernel_flash_attn_ext_vec_reduce( + constant ggml_metal_kargs_flash_attn_ext_vec_reduce & args, + device const char * htmp, + device char * dst, + uint tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define NWG (FC_flash_attn_ext_vec_reduce_NWG) +#define DV (FC_flash_attn_ext_vec_reduce_DV) - const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10; + const uint64_t rid = tgpig; - const int64_t i3 = n / (args.ne12*args.ne11*args.ne10); - const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10); - const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10; + const short iwg = tiisg; - device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs); + device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*NWG; - for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) { - device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10); - dst_data[i10] = (T) src[0]; - } -} + float S = ss[rid*(2*NWG) + 2*iwg + 0]; + float M = ss[rid*(2*NWG) + 2*iwg + 1]; + + const float m = simd_max(M); + const float ms = exp(M - m); + + S = 1.0f/simd_sum(S*ms); + + const short DV4 = DV/4; -typedef decltype(kernel_set) kernel_set_t; + device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*NWG; + device float4 * dst4 = (device float4 *) dst + rid*DV4; -template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set; -template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set; + for (short i = sgitg; i < DV4; i += NWG) { + const float4 v = simd_sum(htmp4[i*NWG + iwg]*ms); + + if (iwg == 0) { + dst4[i] = v*S; + } + } + +#undef NWG +#undef DV +} template kernel void kernel_cpy( @@ -5255,12 +5753,14 @@ typedef decltype(kernel_cpy) kernel_cpy_t; template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; -#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy; +#if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy; #endif template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; -#if defined(GGML_METAL_USE_BF16) +#if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy; template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy; #endif @@ -5502,7 +6002,7 @@ kernel void kernel_concat( } } -template +template void kernel_mul_mv_q2_K_f32_impl( args_t args, device const char * src0, @@ -5512,13 +6012,15 @@ void kernel_mul_mv_q2_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5602,10 +6104,10 @@ kernel void kernel_mul_mv_q2_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q3_K_f32_impl( args_t args, device const char * src0, @@ -5615,6 +6117,7 @@ void kernel_mul_mv_q3_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; @@ -5622,7 +6125,7 @@ void kernel_mul_mv_q3_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5766,10 +6269,10 @@ kernel void kernel_mul_mv_q3_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q4_K_f32_impl( args_t args, device const char * src0, @@ -5779,9 +6282,11 @@ void kernel_mul_mv_q4_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; + const short NSG = FC_mul_mv_nsg; + + constexpr uint16_t kmask1 = 0x3f3f; + constexpr uint16_t kmask2 = 0x0f0f; + constexpr uint16_t kmask3 = 0xc0c0; const short ix = tiisg/8; // 0...3 const short it = tiisg%8; // 0...7 @@ -5794,7 +6299,7 @@ void kernel_mul_mv_q4_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5840,7 +6345,7 @@ void kernel_mul_mv_q4_K_f32_impl( float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (short i = 0; i < 4; ++i) { + FOR_UNROLL (short i = 0; i < 4; ++i) { acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F); acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00); acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0); @@ -5851,14 +6356,11 @@ void kernel_mul_mv_q4_K_f32_impl( acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000); } - float dall = dh[0]; - float dmin = dh[1]; - - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + - (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + - (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + - (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - + dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); q1 += args.nb01/2; sc += args.nb01/2; @@ -5888,10 +6390,10 @@ kernel void kernel_mul_mv_q4_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q5_K_f32_impl( args_t args, device const char * src0, @@ -5901,6 +6403,7 @@ void kernel_mul_mv_q5_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; @@ -5908,7 +6411,7 @@ void kernel_mul_mv_q5_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -5923,9 +6426,9 @@ void kernel_mul_mv_q5_K_f32_impl( float yl[16], yh[16]; - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; + constexpr uint16_t kmask1 = 0x3f3f; + constexpr uint16_t kmask2 = 0x0f0f; + constexpr uint16_t kmask3 = 0xc0c0; const short tid = tiisg/4; const short ix = tiisg%4; @@ -5971,7 +6474,7 @@ void kernel_mul_mv_q5_K_f32_impl( float4 acc1 = {0.f}; float4 acc2 = {0.f}; - for (short l = 0; l < 8; ++l) { + FOR_UNROLL (short l = 0; l < 8; ++l) { uint8_t h = qh[l]; acc1[0] += yl[l+0] * (q1[l] & 0x0F); acc1[1] += yl[l+8] * (q1[l] & 0xF0); @@ -5982,13 +6485,12 @@ void kernel_mul_mv_q5_K_f32_impl( acc2[2] += h & hm3 ? yh[l+0] : 0.f; acc2[3] += h & hm4 ? yh[l+8] : 0.f; } - const float dall = dh[0]; - const float dmin = dh[1]; - sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + - sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + - sc8[4] * (acc1[2] + 16.f*acc2[2]) + - sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + sumf[row] += dh[0] * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + + sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + + sc8[4] * (acc1[2] + 16.f*acc2[2]) + + sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - + dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); q1 += args.nb01; qh += args.nb01; @@ -6019,10 +6521,10 @@ kernel void kernel_mul_mv_q5_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q6_K_f32_impl( args_t args, device const char * src0, @@ -6032,11 +6534,12 @@ void kernel_mul_mv_q6_K_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; - const uint8_t kmask1 = 0x03; - const uint8_t kmask2 = 0x0C; - const uint8_t kmask3 = 0x30; - const uint8_t kmask4 = 0xC0; + constexpr uint8_t kmask1 = 0x03; + constexpr uint8_t kmask2 = 0x0C; + constexpr uint8_t kmask3 = 0x30; + constexpr uint8_t kmask4 = 0xC0; const int nb = args.ne00/QK_K; @@ -6044,7 +6547,7 @@ void kernel_mul_mv_q6_K_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6087,18 +6590,16 @@ void kernel_mul_mv_q6_K_f32_impl( } for (short row = 0; row < nr0; ++row) { - const float dall = dh[0]; - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (short l = 0; l < 4; ++l) { + FOR_UNROLL (short l = 0; l < 4; ++l) { sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); } - sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); q1 += args.nb01; q2 += args.nb01; @@ -6128,12 +6629,12 @@ kernel void kernel_mul_mv_q6_K_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit -template +template void kernel_mul_mv_iq2_xxs_f32_impl( args_t args, device const char * src0, @@ -6143,13 +6644,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6236,10 +6739,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_xs_f32_impl( args_t args, device const char * src0, @@ -6249,13 +6752,15 @@ void kernel_mul_mv_iq2_xs_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6353,10 +6858,10 @@ kernel void kernel_mul_mv_iq2_xs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_xxs_f32_impl( args_t args, device const char * src0, @@ -6366,13 +6871,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6463,10 +6970,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_s_f32_impl( args_t args, device const char * src0, @@ -6476,13 +6983,15 @@ void kernel_mul_mv_iq3_s_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6573,10 +7082,10 @@ kernel void kernel_mul_mv_iq3_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_s_f32_impl( args_t args, device const char * src0, @@ -6586,13 +7095,15 @@ void kernel_mul_mv_iq2_s_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6684,10 +7195,10 @@ kernel void kernel_mul_mv_iq2_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq1_s_f32_impl( args_t args, device const char * src0, @@ -6697,13 +7208,15 @@ void kernel_mul_mv_iq1_s_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6781,10 +7294,10 @@ kernel void kernel_mul_mv_iq1_s_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq1_m_f32_impl( args_t args, device const char * src0, @@ -6794,6 +7307,7 @@ void kernel_mul_mv_iq1_m_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; const int nb = args.ne00/QK_K; @@ -6801,7 +7315,7 @@ void kernel_mul_mv_iq1_m_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6889,10 +7403,10 @@ kernel void kernel_mul_mv_iq1_m_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_nl_f32_impl( args_t args, device const char * src0, @@ -6902,6 +7416,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK4_NL; @@ -6910,7 +7425,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6995,10 +7510,10 @@ kernel void kernel_mul_mv_iq4_nl_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_xs_f32_impl( args_t args, device const char * src0, @@ -7008,13 +7523,15 @@ void kernel_mul_mv_iq4_xs_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7100,10 +7617,10 @@ kernel void kernel_mul_mv_iq4_xs_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_mxfp4_f32_impl( args_t args, device const char * src0, @@ -7113,6 +7630,7 @@ void kernel_mul_mv_mxfp4_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK_MXFP4; @@ -7121,7 +7639,7 @@ void kernel_mul_mv_mxfp4_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * nr0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7189,7 +7707,7 @@ kernel void kernel_mul_mv_mxfp4_f32( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -7312,7 +7830,7 @@ kernel void kernel_set_rows_f( const int32_t i10 = i01; const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; - device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); + device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) { @@ -7411,18 +7929,20 @@ kernel void kernel_mul_mm( #pragma unroll(4) for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(4) for (short i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } - simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) for (short i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(8) for (short i = 0; i < 8; i++){ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); @@ -7474,97 +7994,82 @@ kernel void kernel_mul_mm( } } -template +template // n_expert_used kernel void kernel_mul_mm_id_map0( constant ggml_metal_kargs_mul_mm_id_map0 & args, - device const char * src1, device const char * src2, - device char * hsrc1, device char * htpe, device char * hids, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int ide = tgpig[0]; // expert id - - int n_all = 0; - - device int32_t * ids_i32 = (device int32_t *) (hids); + threadgroup char * shmem [[threadgroup(0)]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort ntg[[threads_per_threadgroup]]) { + const short ide = tpitg; // expert id - for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens - device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21); + uint32_t n_all = 0; - for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used - if (src2_i32[i20] != ide) { - continue; - } + device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21; - device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11); - device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11); + for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens + if (i21 + tpitg < args.ne21) { + device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21); - for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) { - hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]); - } + threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20; - if (tpitg.x == 0) { - ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all; + #pragma unroll(ne20) + for (short i20 = 0; i20 < ne20; i20++) { + sids[i20] = src2_i32[i20]; } - - ++n_all; } - } - - if (tpitg.x == 0) { - device int32_t * tpe_i32 = (device int32_t *) (htpe); - tpe_i32[ide] = n_all; - } -} - -typedef decltype(kernel_mul_mm_id_map0) kernel_mul_mm_id_map0_t; -template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0; + threadgroup_barrier(mem_flags::mem_threadgroup); -template -kernel void kernel_mul_mm_id_map1( - constant ggml_metal_kargs_mul_mm_id_map1 & args, - device const char * hdst, - device const char * hids, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i20 = tgpig[0]; // used expert - const int i21 = tgpig[1]; // token + for (short t = 0; t < ntg; t++) { + if (i21 + t >= args.ne21) { + break; + } - device const int32_t * ids_i32 = (device const int32_t *) (hids); - device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2); + threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20; - const int id = ids_i32[i21*args.ne20 + i20]; + short sel = 0; + #pragma unroll(ne20) + for (short i20 = 0; i20 < ne20; i20++) { + sel += (sids[i20] == ide)*(i20 + 1); + } - const int ide = id / args.neh1; - const int idt = id % args.neh1; + ids_i32[n_all] = (i21 + t)*ne20 + sel - 1; - device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2); + n_all += sel > 0; + } - for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) { - dst_f32x4[i0] = hdst_f32x4[i0]; + threadgroup_barrier(mem_flags::mem_threadgroup); } + + device uint32_t * tpe_u32 = (device uint32_t *) (htpe); + tpe_u32[ide] = n_all; } -typedef decltype(kernel_mul_mm_id_map1) kernel_mul_mm_id_map1_t; +typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t; -template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1; +template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>; +template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>; +template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>; +template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>; +template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; +template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; +template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>; template kernel void kernel_mul_mm_id( constant ggml_metal_kargs_mul_mm_id & args, device const char * src0, device const char * src1, - device const char * tpe, + device const char * htpe, + device const char * hids, device char * dst, threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiitg[[thread_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { threadgroup T * sa = (threadgroup T *)(shmem); @@ -7572,19 +8077,20 @@ kernel void kernel_mul_mm_id( const int r0 = tgpig.y; const int r1 = tgpig.x; - const int im = tgpig.z; + const int im = tgpig.z; // expert - device const int32_t * tpe_i32 = (device const int32_t *) (tpe); + device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe); + device const int32_t * ids_i32 = (device const int32_t *) (hids); - const int neh1 = tpe_i32[im]; + const int32_t neh1 = tpe_u32[im]; if (r1*BLOCK_SIZE_N >= neh1) { return; } // if this block is of 64x32 shape or smaller - const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; - const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; + const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; // a thread shouldn't load data outside of the matrix const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; @@ -7600,20 +8106,23 @@ kernel void kernel_mul_mm_id( short il = (tiitg % THREAD_PER_ROW); - const int i12 = im%args.neh12; - const int i13 = im/args.neh12; + const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col]; - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const short i11 = (id % args.ne20) % args.ne11; + const short i12 = (id / args.ne20); + const short i13 = 0; + + const uint64_t offset0 = im*args.nb02 + i13*args.nb03; const short offset1 = il/nl; device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; - device const half * y = (device const half *)(src1 - + args.nbh13*i13 - + args.nbh12*i12 - + args.nbh11*(r1*BLOCK_SIZE_N + thread_col) - + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + device const float * y = (device const float *)(src1 + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*i11 + + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory @@ -7629,7 +8138,7 @@ kernel void kernel_mul_mm_id( + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; } - *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y); + *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y)); il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2 + nl - 1)/nl : x; @@ -7665,43 +8174,38 @@ kernel void kernel_mul_mm_id( } } - if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) { - device float * C = (device float *) dst + - (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ - (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0; + threadgroup_barrier(mem_flags::mem_threadgroup); - for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0); - } - } else { - // block is smaller than 64x32, we should avoid writing data outside of the matrix - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *) shmem) \ - + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; - for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); - } + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; - threadgroup_barrier(mem_flags::mem_threadgroup); + #pragma unroll(8) + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); + } - if (sgitg == 0) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0; - device float4 * D4 = (device float4 *) D; + threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); - threadgroup float4 * C4 = (threadgroup float4 *) C; + for (short j = sgitg; j < n_cols; j += 4) { + const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j]; - int i = 0; - for (; i < n_rows/4; i++) { - *(D4 + i) = *(C4 + i); - } + const short ide = id % args.ne20; + const short idt = id / args.ne20; - i *= 4; - for (; i < n_rows; i++) { - *(D + i) = *(C + i); - } - } + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = tiisg; + for (; i < n_rows/4; i += 32) { + *(D4 + i) = *(C4 + i); + } + + i = (4*(n_rows/4)) + tiisg; + for (; i < n_rows; i += 32) { + *(D + i) = *(C + i); } } } @@ -7716,7 +8220,7 @@ typedef decltype(kernel_get_rows_f) get_rows_f_t; template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; -#if defined(GGML_METAL_USE_BF16) +#if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; #endif @@ -7751,7 +8255,7 @@ typedef decltype(kernel_set_rows_f) set_rows_f_t; template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f; template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f; -#if defined(GGML_METAL_USE_BF16) +#if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f; #endif @@ -7772,7 +8276,7 @@ typedef decltype(kernel_mul_mm; template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm; -#if defined(GGML_METAL_USE_BF16) +#if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; #endif template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; @@ -7804,7 +8308,7 @@ typedef decltype(kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; -#if defined(GGML_METAL_USE_BF16) +#if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id; #endif template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; @@ -7879,7 +8383,7 @@ void mmv_fn( impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(mmv_fn>) mul_mv_impl_fn_t; +typedef decltype(mmv_fn>) mul_mv_impl_fn_t; template kernel void kernel_mul_mv_id( @@ -7944,44 +8448,52 @@ kernel void kernel_mul_mv_id( sgitg); } -typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; + +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_4_t; -template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -#if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; #endif -template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; - -template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; - -template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; - -template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +#endif + +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + +template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; kernel void kernel_pool_2d_max_f32( + constant ggml_metal_kargs_pool_2d & args, device const float * src0, device float * dst, - constant ggml_metal_kargs_pool_2d & args, uint gid[[thread_position_in_grid]]) { - if (gid >= args.parallel_elements) { + if (gid >= args.np) { return; } @@ -8014,12 +8526,12 @@ kernel void kernel_pool_2d_max_f32( } kernel void kernel_pool_2d_avg_f32( + constant ggml_metal_kargs_pool_2d & args, device const float * src0, device float * dst, - constant ggml_metal_kargs_pool_2d & args, uint gid[[thread_position_in_grid]]) { - if (gid >= args.parallel_elements) { + if (gid >= args.np) { return; } diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index 02904526ade..cdb3818c786 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -96,10 +96,6 @@ if (MUSAToolkit_FOUND) add_compile_definitions(GGML_CUDA_NO_FA) endif() - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) - add_compile_definitions(GGML_CUDA_F16) - endif() - if (GGML_CUDA_NO_PEER_COPY) add_compile_definitions(GGML_CUDA_NO_PEER_COPY) endif() diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 9a7ccbcff06..1c06aa138bf 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -83,8 +83,10 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_0_f32_1d_16x_flat mul_mv_q6_k mul_mv_mxfp4_f32 + mul_mv_mxfp4_f32_flat mul_mv_id_q4_0_f32_8x_flat mul_mv_id_mxfp4_f32 + mul_mv_id_mxfp4_f32_flat mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm mul diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 8a2ac7e377d..2cb838b7139 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -333,6 +333,7 @@ struct ggml_backend_opencl_context { cl_int alignment; size_t max_alloc_size; + size_t max_workgroup_size; bool fp16_support; bool has_vector_subgroup_broadcast; bool disable_fusion; @@ -367,6 +368,7 @@ struct ggml_backend_opencl_context { cl_program program_mul_mv_q4_0_f32_1d_16x_flat; cl_program program_mul_mv_q6_K; cl_program program_mul_mv_mxfp4_f32; + cl_program program_mul_mv_mxfp4_f32_flat; cl_program program_mul_mv_f16_f16; cl_program program_mul_mv_f16_f32_1row; cl_program program_mul_mv_f16_f32_l4; @@ -401,6 +403,7 @@ struct ggml_backend_opencl_context { cl_program program_tsembd; cl_program program_mul_mv_id_q4_0_f32_8x_flat; cl_program program_mul_mv_id_mxfp4_f32; + cl_program program_mul_mv_id_mxfp4_f32_flat; cl_program program_mul_mm_f32_f32_l4_lm; cl_program program_mul_mm_f16_f32_l4_lm; @@ -419,9 +422,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_clamp; cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick, kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16; - cl_kernel kernel_norm; + cl_kernel kernel_norm, kernel_norm_mul_add; cl_kernel kernel_rms_norm, kernel_rms_norm_mul; - cl_kernel kernel_group_norm; + cl_kernel kernel_group_norm, kernel_group_norm_mul_add; cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; cl_kernel kernel_soft_max, kernel_soft_max_4; cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; @@ -446,11 +449,12 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_f16_f32_tiled; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; + cl_kernel kernel_convert_block_mxfp4, kernel_restore_block_mxfp4; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q6_K_f32; - cl_kernel kernel_mul_mv_mxfp4_f32; + cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; cl_kernel kernel_im2col_f32, kernel_im2col_f16; cl_kernel kernel_argsort_f32_i32; cl_kernel kernel_sum_rows_f32; @@ -468,6 +472,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_timestep_embedding; cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; cl_kernel kernel_mul_mv_id_mxfp4_f32; + cl_kernel kernel_mul_mv_id_mxfp4_f32_flat; cl_kernel kernel_mul_mm_f32_f32_l4_lm; cl_kernel kernel_mul_mm_f16_f32_l4_lm; @@ -764,6 +769,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err)); GGML_LOG_CONT("."); } @@ -1001,6 +1008,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_mxfp4_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_mxfp4_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_mxfp4_f32_flat.cl"); +#endif + backend_ctx->program_mul_mv_mxfp4_f32_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_mxfp4_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_mxfp4_f32_flat, "kernel_mul_mv_mxfp4_f32_flat", &err), err)); + GGML_LOG_CONT("."); + } + // mul_mv_f16_f16 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1160,7 +1183,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve backend_ctx->program_norm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_norm_mul_add = clCreateKernel(backend_ctx->program_norm, "kernel_norm_mul_add", &err), err)); GGML_LOG_CONT("."); } @@ -1337,7 +1361,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { - { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, + { 40, 40, 32, 32}, { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, {192, 192, 16, 16}, {256, 256, 16, 16}, }; @@ -1486,7 +1510,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve backend_ctx->program_group_norm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_group_norm_mul_add = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm_mul_add", &err), err)); GGML_LOG_CONT("."); } @@ -1724,6 +1749,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_id_mxfp4_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_id_mxfp4_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_id_mxfp4_f32_flat.cl"); +#endif + backend_ctx->program_mul_mv_id_mxfp4_f32_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_id_mxfp4_f32_flat, "kernel_mul_mv_id_mxfp4_f32_flat", &err), err)); + GGML_LOG_CONT("."); + } + // Adreno kernels #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // transpose @@ -2218,6 +2259,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024); + clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL); + GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size); + // Check SVM. cl_device_svm_capabilities svm_caps; CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0)); @@ -2385,6 +2429,51 @@ struct ggml_tensor_extra_cl_q4_0 { } }; +struct ggml_tensor_extra_cl_mxfp4 { + // Quantized values. + cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; + // Scales in E8M0. + cl_mem e = nullptr; + // Scales in image1d_buffer_t. + cl_mem e_img = nullptr; + // Size of quantized values. + size_t size_q = 0; + // Size of scales. + size_t size_e = 0; + + ~ggml_tensor_extra_cl_mxfp4() { + reset(); + } + + void reset() { + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. + // They must be properly released so that the original buffer can be + // properly released to avoid memory leak. + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (e != nullptr) { + CL_CHECK(clReleaseMemObject(e)); + e = nullptr; + } + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q = nullptr; + } + // Currently, q_img and d_img are only initialized when SMALL_ALLOC is + // enabled. They point to the images in ggml_backend_opencl_buffer_context. + // So, there is no need to release them here. + // TODO: initialize them for non SMALL_PATH path, or remove them. + q_img = nullptr; + e_img = nullptr; + size_q = 0; + size_e = 0; + } +}; + //------------------------------------------------------------------------------ // Backend API //------------------------------------------------------------------------------ @@ -2494,12 +2583,47 @@ static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { return false; } + } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) { + const ggml_tensor *norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + const ggml_tensor *add = cgraph->nodes[node_idx+2]; + const ggml_tensor *w = mul->src[0] == norm ? mul->src[1] : mul->src[0]; + const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0]; + + // norm fusion only supports F32 + if (norm->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) { + return false; + } + + if (norm->src[0]->ne[0] % 4 != 0) { + return false; + } + + if (!ggml_is_contiguous(norm->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) { + return false; + } + } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_GROUP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) { + const ggml_tensor *gn = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + const ggml_tensor *add = cgraph->nodes[node_idx+2]; + const ggml_tensor *w = mul->src[0] == gn ? mul->src[1] : mul->src[0]; + const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0]; + + if (gn->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) { + return false; + } + + if (!ggml_is_contiguous(gn->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) { + return false; + } } return true; } static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor); +static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor); +static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor); static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -2516,6 +2640,16 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm continue; } + if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) { + ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + continue; + } + if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_GROUP_NORM, GGML_OP_MUL, GGML_OP_ADD })) { + ggml_opencl_op_group_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + continue; + } if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]); i++; @@ -2533,7 +2667,8 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm } static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { - GGML_UNUSED(dev); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context; + ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; switch (op->op) { case GGML_OP_NONE: @@ -2642,13 +2777,16 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SOFT_MAX: case GGML_OP_NORM: - case GGML_OP_RMS_NORM: return true; + case GGML_OP_RMS_NORM: + return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_REPEAT: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded case GGML_OP_PAD: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && - op->src[0]->ne[3] == 1 && op->ne[3] == 1; + op->src[0]->ne[3] == 1 && op->ne[3] == 1 && + (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && + (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: @@ -2708,16 +2846,21 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } case GGML_OP_IM2COL: return true; - case GGML_OP_ARGSORT: - return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ARGSORT: { + cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32; + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + + int cols = 1; + while (cols < op->ne[0]) { + cols *= 2; + } + + return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32; + } case GGML_OP_SUM_ROWS: return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); case GGML_OP_FLASH_ATTN_EXT: { - if (op->src[4]) { - return false; - } - const ggml_tensor * q = op->src[0]; const ggml_tensor * k = op->src[1]; const ggml_tensor * v = op->src[2]; @@ -2726,7 +2869,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te const int dv = v->ne[0]; const struct { int dk; int dv; } supported_dims[] = { - { 64, 64}, { 80, 80}, { 96, 96}, + { 40, 40}, { 64, 64}, { 80, 80}, { 96, 96}, {112, 112}, {128, 128}, {192, 128}, {192, 192}, {256, 256}, }; @@ -2778,6 +2921,7 @@ static ggml_backend_i ggml_backend_opencl_i = { /* .graph_compute = */ ggml_backend_opencl_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, }; ggml_backend_t ggml_backend_opencl_init(void) { @@ -2833,6 +2977,12 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) { delete e; } + for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4) { + delete e; + } + for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) { + delete e; + } } ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() { @@ -2865,6 +3015,21 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() { + ggml_tensor_extra_cl_mxfp4 * extra; + if (temp_tensor_extras_mxfp4.empty()) { + extra = new ggml_tensor_extra_cl_mxfp4(); + } else { + extra = temp_tensor_extras_mxfp4.back(); + temp_tensor_extras_mxfp4.pop_back(); + } + + temp_tensor_extras_mxfp4_in_use.push_back(extra); + + extra->reset(); + return extra; + } + void reset() { for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) { temp_tensor_extras.push_back(e); @@ -2875,6 +3040,11 @@ struct ggml_backend_opencl_buffer_context { temp_tensor_extras_q4_0.push_back(e); } temp_tensor_extras_q4_0_in_use.clear(); + + for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) { + temp_tensor_extras_mxfp4.push_back(e); + } + temp_tensor_extras_mxfp4_in_use.clear(); } // Pools for extras. Available extras are in `temp_tensor_extras`. Extras @@ -2886,6 +3056,8 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_in_use; std::vector temp_tensor_extras_q4_0; std::vector temp_tensor_extras_q4_0_in_use; + std::vector temp_tensor_extras_mxfp4; + std::vector temp_tensor_extras_mxfp4_in_use; // The buffer_context is initially created by ggml_backend_buft_alloc_buffer // before any tensor is initialized (at the beginning of alloc_tensor_range). @@ -3228,6 +3400,76 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, } #endif // GGML_OPENCL_USE_ADRENO_KERNELS + return; + + } + if (tensor->type == GGML_TYPE_MXFP4) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_mxfp4 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_mxfp4(); + + size_t size_e = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(char); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_e + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, then quants. + cl_buffer_region region; + + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_e; + extra->e = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_e, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_q = {CL_RG, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor)/32*2), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + + tensor->extra = extra; + return; } #endif // GGML_OPENCL_SOA_Q @@ -3276,6 +3518,31 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {1, 1, 1}; + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } else if (tensor->type == GGML_TYPE_MXFP4) { + ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + cl_event evt; CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); @@ -3597,6 +3864,19 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL)); CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL)); CL_CHECK(clFinish(queue)); + } else if (tensor->type == GGML_TYPE_MXFP4) { + ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *) tensor->extra; + GGML_ASSERT(extra); + + size_t size_q = ggml_nelements(tensor)/QK_MXFP4 * QK_MXFP4/2; + size_t size_e = ggml_nelements(tensor)/QK_MXFP4 * sizeof(char); + GGML_ASSERT(size_q + size_e == ggml_nbytes(tensor)); + buf_q = malloc(size_q); + buf_d = malloc(size_e); + + CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL)); + CL_CHECK(clFinish(queue)); } else { // Read out the tensor from GPU memory. ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; @@ -5024,6 +5304,140 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) { + GGML_ASSERT(norm_tensor && mul_tensor && add_tensor); + + const ggml_tensor * src0 = norm_tensor->src[0]; + const ggml_tensor * src1 = mul_tensor->src[0] == norm_tensor ? mul_tensor->src[1] : mul_tensor->src[0]; + const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0]; + const ggml_tensor * dst = add_tensor; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offset2 = extra2->offset + src2->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + float eps; + memcpy(&eps, norm_tensor->op_params, sizeof(float)); + + const int ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3]; + const cl_ulong nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3]; + const int ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2], ne13 = src1->ne[3]; + const cl_ulong nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3]; + const int ne20 = src2->ne[0], ne21 = src2->ne[1], ne22 = src2->ne[2], ne23 = src2->ne[3]; + const cl_ulong nb21 = src2->nb[1], nb22 = src2->nb[2], nb23 = src2->nb[3]; + const cl_ulong nbd1 = dst->nb[1], nbd2 = dst->nb[2], nbd3 = dst->nb[3]; + + size_t sgs; + if (backend_ctx->gpu_family == ADRENO) sgs = 64; + else if (backend_ctx->gpu_family == INTEL) sgs = 32; + else GGML_ASSERT(false && "Unsupported GPU"); + + cl_kernel kernel = backend_ctx->kernel_norm_mul_add; + + int nth = sgs; + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + while (nth < ne00/4 && nth < max_workgroup_size) nth *= 2; + nth = MIN(nth, max_workgroup_size); + nth = MIN(nth, ne00/4); + + size_t gws[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t lws[] = {(size_t)nth, 1, 1}; + size_t num_subgroups = (nth + sgs - 1) / sgs; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne22)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne23)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb21)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb22)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb23)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nbd1)); + CL_CHECK(clSetKernelArg(kernel, 30, sizeof(cl_ulong), &nbd2)); + CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_ulong), &nbd3)); + CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_float2) * num_subgroups, NULL)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, gws, lws, dst); +} + +static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) { + GGML_ASSERT(gn_tensor && mul_tensor && add_tensor); + + const ggml_tensor * src0 = gn_tensor->src[0]; + const ggml_tensor * src1 = mul_tensor->src[0] == gn_tensor ? mul_tensor->src[1] : mul_tensor->src[0]; + const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0]; + const ggml_tensor * dst = add_tensor; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offset2 = extra2->offset + src2->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + int groups; + float eps; + memcpy(&groups, gn_tensor->op_params, sizeof(int)); + memcpy(&eps, (char *)gn_tensor->op_params + sizeof(int), sizeof(float)); + + cl_kernel kernel = backend_ctx->kernel_group_norm_mul_add; + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + int ne = ggml_nelements(src0); + int group_size = ne / groups; + + size_t lws[] = { (size_t)MIN(max_workgroup_size, group_size) }; + size_t gws[] = { (size_t)groups * lws[0] }; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &group_size)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &eps)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst); +} + static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -5569,6 +5983,7 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) { const ggml_tensor * v = dst->src[2]; const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; GGML_ASSERT(q->extra); GGML_ASSERT(k->extra); GGML_ASSERT(v->extra); @@ -5576,6 +5991,9 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co if (mask) { GGML_ASSERT(mask->extra); } + if (sinks) { + GGML_ASSERT(sinks->extra); + } ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -5617,6 +6035,7 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra; ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra; ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL; + ggml_tensor_extra_cl * extra_sinks = sinks ? (ggml_tensor_extra_cl *)sinks->extra : NULL; cl_ulong offset_q = extra_q->offset + q->view_offs; cl_ulong offset_k = extra_k->offset + k->view_offs; @@ -5624,6 +6043,8 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co cl_ulong offset_o = extra_o->offset + dst->view_offs; cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL; cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0; + cl_mem sinks_buffer = extra_sinks ? extra_sinks->data_device : NULL; + cl_ulong offset_sinks = extra_sinks ? extra_sinks->offset + sinks->view_offs : 0; const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3]; const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3]; @@ -5678,6 +6099,8 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3)); CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2)); CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3)); + CL_CHECK(clSetKernelArg(kernel, 38, sizeof(cl_mem), &sinks_buffer)); + CL_CHECK(clSetKernelArg(kernel, 39, sizeof(cl_ulong), &offset_sinks)); if (n_q == 1) { const size_t wg_size = 64; @@ -5844,6 +6267,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; + ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; #endif const int ne00 = src0 ? src0->ne[0] : 0; @@ -6548,6 +6972,45 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); break; case GGML_TYPE_MXFP4: { +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_mxfp4_f32_flat; + + cl_mem q; + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 2; + ndst = nth1*2; + + q = extra0_mxfp4->q; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 2; + ndst = nth1*2; + + q = extra0_mxfp4->q_img; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_mxfp4->e)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r3)); +#else kernel = backend_ctx->kernel_mul_mv_mxfp4_f32; if (backend_ctx->gpu_family == INTEL) { @@ -6581,6 +7044,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r2)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r3)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float)*nth0,nullptr)); +#endif break; } default: @@ -6646,8 +7110,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, cl_ulong offset2 = extra2->offset + src2->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; + GGML_UNUSED(offset0); + #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; + ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; #endif const int ne00 = src0->ne[0]; @@ -6736,6 +7203,51 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, break; } case GGML_TYPE_MXFP4: { +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat; + + cl_mem q; + if (backend_ctx->gpu_family == INTEL) { + sgs = 16; + nsg = 2; + ndst = 2; + + q = extra0_mxfp4->q; + } else if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + nsg = 1; + ndst = 4; + + q = extra0_mxfp4->q_img; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_mxfp4->e)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb21)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3)); +#else // GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32; if (backend_ctx->gpu_family == INTEL) { @@ -6775,7 +7287,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs,nullptr)); - +#endif // GGML_OPENCL_SOA_Q break; } default: diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index fe7975e3dbf..3440ff50796 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -116,3 +116,49 @@ kernel void kernel_convert_block_q4_0_noshuffle( #endif } } + + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +#define QK_MXFP4 32 +struct block_mxfp4 { + uchar e; // E8M0 + uchar qs[QK_MXFP4 / 2]; +}; + +//------------------------------------------------------------------------------ +// kernel_convert_block_mxfp4 +// Convert the block_mxfp4 format to 2 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_mxfp4( + global struct block_mxfp4 * src0, + global uchar * dst_q, + global uchar * dst_e +) { + global struct block_mxfp4 * b = (global struct block_mxfp4 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_MXFP4 / 2 * get_global_id(0); + global uchar * e = (global uchar *) dst_e + get_global_id(0); + + *e = b->e; + + for (int i = 0; i < QK_MXFP4 / 2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_mxfp4( + global uchar * src_q, + global half * src_e, + global struct block_mxfp4 * dst +) { + global struct block_mxfp4 * b = (global struct block_mxfp4 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_MXFP4 / 2 * get_global_id(0); + global uchar * e = (global uchar *) src_e + get_global_id(0); + + b->e = *e; + for (int i = 0; i < QK_MXFP4 / 2; ++i) { + b->qs[i] = q[i]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl index fea06867e10..8f43c4f27d5 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl @@ -49,7 +49,9 @@ __kernel void flash_attn_f16( const ulong mask_nb2, const ulong mask_nb3, const int mask_ne2, - const int mask_ne3 + const int mask_ne3, + const global void* sinks_void, + const ulong sinks_offset ) { const int tid = get_local_id(0); const int block_q_idx = get_group_id(0); @@ -171,6 +173,20 @@ __kernel void flash_attn_f16( } if (my_query_row < n_q) { + if (sinks_void != NULL) { + const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset); + const ACC_TYPE m_sink = sinks_ptr[head_idx]; + const ACC_TYPE m_final = max(m_i, m_sink); + + const ACC_TYPE scale_o = exp(m_i - m_final); + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) { + o_acc[i] *= scale_o; + } + + l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final); + } + const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1; global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); if (l_i > 0.0f) { @@ -214,7 +230,9 @@ __kernel void flash_attn_f16_q1( const ulong mask_nb2, const ulong mask_nb3, const int mask_ne2, - const int mask_ne3 + const int mask_ne3, + const global void* sinks_void, + const ulong sinks_offset ) { const int tid = get_local_id(0); const int head_batch_idx = get_global_id(1); @@ -247,7 +265,12 @@ __kernel void flash_attn_f16_q1( float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); - ACC_TYPE m_i = -INFINITY; + const global ACC_TYPE* sinks_ptr = NULL; + if (sinks_void != NULL) { + sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset); + } + + ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY; for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); @@ -320,7 +343,11 @@ __kernel void flash_attn_f16_q1( const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1; global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); - const ACC_TYPE l_final = local_l[0]; + ACC_TYPE l_final = local_l[0]; + + if (sinks_ptr != NULL) { + l_final += exp(sinks_ptr[head_idx] - m_final); + } if (l_final > 0.0f) { const ACC_TYPE l_inv = 1.0f / l_final; diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl index 2d657327d64..9c0bab135a9 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl @@ -49,7 +49,9 @@ __kernel void flash_attn_f32( const ulong mask_nb2, const ulong mask_nb3, const int mask_ne2, - const int mask_ne3 + const int mask_ne3, + const global void* sinks_void, + const ulong sinks_offset ) { const int tid = get_local_id(0); const int block_q_idx = get_group_id(0); @@ -171,6 +173,20 @@ __kernel void flash_attn_f32( } if (my_query_row < n_q) { + if (sinks_void != NULL) { + const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset); + const ACC_TYPE m_sink = sinks_ptr[head_idx]; + const ACC_TYPE m_final = max(m_i, m_sink); + + const ACC_TYPE scale_o = exp(m_i - m_final); + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) { + o_acc[i] *= scale_o; + } + + l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final); + } + const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1; global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); if (l_i > 0.0f) { @@ -214,7 +230,9 @@ __kernel void flash_attn_f32_q1( const ulong mask_nb2, const ulong mask_nb3, const int mask_ne2, - const int mask_ne3 + const int mask_ne3, + const global void* sinks_void, + const ulong sinks_offset ) { const int tid = get_local_id(0); const int head_batch_idx = get_global_id(1); @@ -247,7 +265,12 @@ __kernel void flash_attn_f32_q1( float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); - ACC_TYPE m_i = -INFINITY; + const global ACC_TYPE* sinks_ptr = NULL; + if (sinks_void != NULL) { + sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset); + } + + ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY; for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); @@ -320,7 +343,11 @@ __kernel void flash_attn_f32_q1( const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1; global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); - const ACC_TYPE l_final = local_l[0]; + ACC_TYPE l_final = local_l[0]; + + if (sinks_ptr != NULL) { + l_final += exp(sinks_ptr[head_idx] - m_final); + } if (l_final > 0.0f) { const ACC_TYPE l_inv = 1.0f / l_final; diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl index 7067bd2591f..ec7361b9e37 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl @@ -52,7 +52,9 @@ __kernel void flash_attn_f32_f16( const ulong mask_nb2, const ulong mask_nb3, const int mask_ne2, - const int mask_ne3 + const int mask_ne3, + const global void* sinks_void, + const ulong sinks_offset ) { const int tid = get_local_id(0); const int block_q_idx = get_group_id(0); @@ -174,6 +176,20 @@ __kernel void flash_attn_f32_f16( } if (my_query_row < n_q) { + if (sinks_void != NULL) { + const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset); + const ACC_TYPE m_sink = sinks_ptr[head_idx]; + const ACC_TYPE m_final = max(m_i, m_sink); + + const ACC_TYPE scale_o = exp(m_i - m_final); + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) { + o_acc[i] *= scale_o; + } + + l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final); + } + const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1; global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset); if (l_i > 0.0f) { @@ -217,7 +233,9 @@ __kernel void flash_attn_f32_f16_q1( const ulong mask_nb2, const ulong mask_nb3, const int mask_ne2, - const int mask_ne3 + const int mask_ne3, + const global void* sinks_void, + const ulong sinks_offset ) { const int tid = get_local_id(0); const int head_batch_idx = get_global_id(1); @@ -250,7 +268,12 @@ __kernel void flash_attn_f32_f16_q1( float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); - ACC_TYPE m_i = -INFINITY; + const global ACC_TYPE* sinks_ptr = NULL; + if (sinks_void != NULL) { + sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset); + } + + ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY; for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset); @@ -323,7 +346,11 @@ __kernel void flash_attn_f32_f16_q1( const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1; global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset); - const ACC_TYPE l_final = local_l[0]; + ACC_TYPE l_final = local_l[0]; + + if (sinks_ptr != NULL) { + l_final += exp(sinks_ptr[head_idx] - m_final); + } if (l_final > 0.0f) { const ACC_TYPE l_inv = 1.0f / l_final; diff --git a/ggml/src/ggml-opencl/kernels/group_norm.cl b/ggml/src/ggml-opencl/kernels/group_norm.cl index 57c9df4d35b..8e4fa0ed12d 100644 --- a/ggml/src/ggml-opencl/kernels/group_norm.cl +++ b/ggml/src/ggml-opencl/kernels/group_norm.cl @@ -70,3 +70,52 @@ kernel void kernel_group_norm( dst[j] *= scale; } } + +//------------------------------------------------------------------------------ +// group_norm_mul_add +//------------------------------------------------------------------------------ +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_group_norm_mul_add( + global float * src0, ulong offset0, + global float * src1, ulong offset1, + global float * src2, ulong offset2, + global float * dst, ulong offsetd, + int ne, + int group_size, + float eps +) { + src0 = (global float *)((global char *)src0 + offset0); + src1 = (global float *)((global char *)src1 + offset1); + src2 = (global float *)((global char *)src2 + offset2); + dst = (global float *)((global char *)dst + offsetd); + + int start = get_group_id(0) * group_size; + int end = start + group_size; + if (end > ne) { + end = ne; + } + + float sum = 0.0f; + float sum_sq = 0.0f; + + for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) { + float val = src0[j]; + sum += val; + sum_sq += val*val; + } + + sum = sub_group_reduce_add(sum); + sum_sq = sub_group_reduce_add(sum_sq); + + const float mean = sum / group_size; + const float var = sum_sq / group_size - mean * mean; + const float scale = rsqrt(var + eps); + + for (int j = start + get_local_id(0); j < end; j += get_local_size(0)) { + dst[j] = ((src0[j] - mean) * scale) * src1[j] + src2[j]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl new file mode 100644 index 00000000000..f65e86ed6a2 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl @@ -0,0 +1,176 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK_MXFP4 32 + +static inline half4 mxfp4_to_fp16_packed(ushort fp4x4) { + ushort2 fp16_packed_a, fp16_packed_b, bias_a, bias_b, sign_a, sign_b; + fp16_packed_a.lo = (fp4x4 << 9) & 0x0E00; + fp16_packed_a.hi = (fp4x4 << 5) & 0x0E00; + fp16_packed_b.lo = (fp4x4 << 1) & 0x0E00; + fp16_packed_b.hi = (fp4x4 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a.lo == 0) ? 0x0 : 0x3800; + bias_a.hi = (fp16_packed_a.hi == 0) ? 0x0 : 0x3800; + bias_b.lo = (fp16_packed_b.lo == 0) ? 0x0 : 0x3800; + bias_b.hi = (fp16_packed_b.hi == 0) ? 0x0 : 0x3800; + + fp16_packed_a.lo = (fp16_packed_a.lo == 0x0200) ? 0x0 : fp16_packed_a.lo; + fp16_packed_a.hi = (fp16_packed_a.hi == 0x0200) ? 0x0 : fp16_packed_a.hi; + fp16_packed_b.lo = (fp16_packed_b.lo == 0x0200) ? 0x0 : fp16_packed_b.lo; + fp16_packed_b.hi = (fp16_packed_b.hi == 0x0200) ? 0x0 : fp16_packed_b.hi; + + sign_a.lo = (fp4x4 << 12) & 0x8000; + sign_a.hi = (fp4x4 << 8) & 0x8000; + sign_b.lo = (fp4x4 << 4) & 0x8000; + sign_b.hi = fp4x4 & 0x8000; + + fp16_packed_a = sign_a + bias_a + fp16_packed_a; + fp16_packed_b = sign_b + bias_b + fp16_packed_b; + + return as_half4((ushort4)(fp16_packed_a, fp16_packed_b)); +} + +static inline float e8m0_to_fp32(uchar x) { + int bits; + bits = (x == 0) ? 0x00400000 : ((uint) x << 23); + return as_float(bits); +} + +#ifdef INTEL_GPU +#define N_R0_MXFP4 2 // number of rows each subgroup works on +#define N_SG_MXFP4 2 // number of subgroups in a work group +#define N_SIMDWIDTH 16 // subgroup size +#elif defined (ADRENO_GPU) +#define N_R0_MXFP4 4 +#define N_SG_MXFP4 1 +#define N_SIMDWIDTH 64 +#define SRC0Q_IMG +#endif + +kernel void kernel_mul_mv_id_mxfp4_f32_flat( +#ifdef SRC0Q_IMG + __read_only image1d_buffer_t src0_q, +#else + global uchar * src0_q, +#endif + global uchar * src0_e, + global uchar * src1, + ulong offset1, + global uchar * src2, + ulong offset2, + global uchar * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne11, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne20, + int ne21, + ulong nb21, + int ne0, + int ne1, + int r2, + int r3 +) { + dst = dst + offsetd; + + const int iid1 = get_group_id(2) / ne20; + const int idx = get_group_id(2) % ne20; + + uint i02 = ((global uint *) (src2 + offset2 + iid1 * nb21))[idx]; + + int i11 = idx % ne11; + + int nb = ne00 / QK_MXFP4; + + uint src0_off = i02*nb02; + src0_off /= 17; // 17 = sizeof(block_mxfp4) + + src0_e = src0_e + src0_off; + + dst = dst + (idx * ne0 + iid1 * ne1 * ne0) * sizeof(float); + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + + int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4; + + uint offset_src0 = first_row*nb01; + offset_src0 /= 17; // 17 = sizeof(block_mxfp4) +#ifdef SRC0Q_IMG + ulong offset_q = src0_off + offset_src0; +#else + src0_q = src0_q + src0_off*16; + global uchar16 * x_q = (global uchar16 *)(src0_q) + offset_src0; +#endif + global uchar * x_e = src0_e + offset_src0; + + const short ix = get_sub_group_local_id() >> 1; + const short it = get_sub_group_local_id() & 1; + + float sumf[N_R0_MXFP4] = {0.f}; + + src1 = src1 + offset1 + i11 * nb11 + iid1 * nb12; + global float * y = (global float *) (src1 + r1 * nb11); + global float * yb = y + ix * QK_MXFP4 + it * 8; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH / 2) { + global float4 * y4 = (global float4 *)yb; + + #pragma unroll + for (short row = 0; row < N_R0_MXFP4; row++) { + uchar xb_e = x_e[row * nb + ib]; +#ifdef SRC0Q_IMG + ushort4 xb_q = as_ushort4(read_imageui(src0_q, (offset_q + row * nb + ib) * 2 + it).xy); +#else + ushort4 xb_q = vload4(0, (global ushort *)((global uchar *)(x_q + row * nb + ib) + 8 * it)); +#endif + + half4 fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s0); + half4 fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s1); + float4 acc1 = y4[0] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2); + acc1 += y4[4] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3); + + fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s2); + fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s3); + acc1 += y4[1] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2); + acc1 += y4[5] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3); + + sumf[row] += e8m0_to_fp32(xb_e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3)); + } + + yb += (N_SIMDWIDTH / 2) * QK_MXFP4; + } + + global float * dst_f32 = (global float *)dst + (ulong)r1 * ne0; + + for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) { + float sum_all = sub_group_reduce_add(sumf[row]); + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = sum_all; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl new file mode 100644 index 00000000000..3d5a923eee0 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl @@ -0,0 +1,167 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK_MXFP4 32 + +static inline half4 mxfp4_to_fp16_packed(ushort fp4x4) { + ushort2 fp16_packed_a, fp16_packed_b, bias_a, bias_b, sign_a, sign_b; + fp16_packed_a.lo = (fp4x4 << 9) & 0x0E00; + fp16_packed_a.hi = (fp4x4 << 5) & 0x0E00; + fp16_packed_b.lo = (fp4x4 << 1) & 0x0E00; + fp16_packed_b.hi = (fp4x4 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a.lo == 0) ? 0x0 : 0x3800; + bias_a.hi = (fp16_packed_a.hi == 0) ? 0x0 : 0x3800; + bias_b.lo = (fp16_packed_b.lo == 0) ? 0x0 : 0x3800; + bias_b.hi = (fp16_packed_b.hi == 0) ? 0x0 : 0x3800; + + fp16_packed_a.lo = (fp16_packed_a.lo == 0x0200) ? 0x0 : fp16_packed_a.lo; + fp16_packed_a.hi = (fp16_packed_a.hi == 0x0200) ? 0x0 : fp16_packed_a.hi; + fp16_packed_b.lo = (fp16_packed_b.lo == 0x0200) ? 0x0 : fp16_packed_b.lo; + fp16_packed_b.hi = (fp16_packed_b.hi == 0x0200) ? 0x0 : fp16_packed_b.hi; + + sign_a.lo = (fp4x4 << 12) & 0x8000; + sign_a.hi = (fp4x4 << 8) & 0x8000; + sign_b.lo = (fp4x4 << 4) & 0x8000; + sign_b.hi = fp4x4 & 0x8000; + + fp16_packed_a = sign_a + bias_a + fp16_packed_a; + fp16_packed_b = sign_b + bias_b + fp16_packed_b; + + return as_half4((ushort4)(fp16_packed_a, fp16_packed_b)); +} + +static inline float e8m0_to_fp32(uchar x) { + int bits; + bits = (x == 0) ? 0x00400000 : ((uint) x << 23); + return as_float(bits); +} + +#ifdef INTEL_GPU +#define N_R0_MXFP4 2 // number of rows each subgroup works on +#define N_SG_MXFP4 2 // number of subgroups in a work group +#define N_SIMDWIDTH 16 // subgroup size +#elif defined (ADRENO_GPU) +#define N_R0_MXFP4 2 +#define N_SG_MXFP4 2 +#define N_SIMDWIDTH 64 +#define SRC0Q_IMG +#endif + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_mxfp4_f32_flat( +#ifdef SRC0Q_IMG + __read_only image1d_buffer_t src0_q, +#else + global uchar * src0_q, +#endif + global uchar * src0_e, + global uchar * src1, + ulong offset1, + global uchar * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = src1 + offset1; + dst = dst + offsetd; + + int nb = ne00 / QK_MXFP4; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4; + + uint i12 = im % ne12; + uint i13 = im / ne12; + + uint offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + // 17 = sizeof(block_mxfp4) + offset_src0 /= 17; +#ifdef SRC0Q_IMG + ulong offset_q = offset_src0; +#else + global uchar16 * x_q = (global uchar16 *)(src0_q) + offset_src0; +#endif + global uchar * x_e = src0_e + offset_src0; + + ulong offset_src1 = r1 * nb11 + i12 * nb12 + i13 * nb13; + global float * y = (global float *)(src1 + offset_src1); + + const short ix = get_sub_group_local_id() >> 1; // 0...15 + const short it = get_sub_group_local_id() & 1; // 0 or 1 + + float sumf[N_R0_MXFP4] = {0.f}; + + global float * yb = y + ix * QK_MXFP4 + it * 8; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + global float4 * y4 = (global float4 *)yb; + + #pragma unroll + for (short row = 0; row < N_R0_MXFP4; row++) { + uchar xb_e = x_e[row * nb + ib]; +#ifdef SRC0Q_IMG + ushort4 xb_q = as_ushort4(read_imageui(src0_q, (offset_q + row * nb + ib) * 2 + it).xy); +#else + ushort4 xb_q = vload4(0, (global ushort *)((global uchar *)(x_q + row * nb + ib) + 8 * it)); +#endif + + half4 fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s0); + half4 fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s1); + float4 acc1 = y4[0] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2); + acc1 += y4[4] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3); + + fp16x4_0 = mxfp4_to_fp16_packed(xb_q.s2); + fp16x4_1 = mxfp4_to_fp16_packed(xb_q.s3); + acc1 += y4[1] * (float4)(fp16x4_0.s0, fp16x4_0.s2, fp16x4_1.s0, fp16x4_1.s2); + acc1 += y4[5] * (float4)(fp16x4_0.s1, fp16x4_0.s3, fp16x4_1.s1, fp16x4_1.s3); + + sumf[row] += e8m0_to_fp32(xb_e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3)); + } + + yb += (N_SIMDWIDTH/2) * QK_MXFP4; + } + + global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0; + + for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) { + float sum_all = sub_group_reduce_add(sumf[row]); + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = sum_all; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/norm.cl b/ggml/src/ggml-opencl/kernels/norm.cl index 43167ba4d22..170f822787b 100644 --- a/ggml/src/ggml-opencl/kernels/norm.cl +++ b/ggml/src/ggml-opencl/kernels/norm.cl @@ -79,3 +79,83 @@ kernel void kernel_norm( y[i00] = y[i00] * scale; } } + +//------------------------------------------------------------------------------ +// norm_mul_add +//------------------------------------------------------------------------------ +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_norm_mul_add( + global char * src0_ptr, ulong src0_offset, + global char * src1_ptr, ulong src1_offset, + global char * src2_ptr, ulong src2_offset, + global char * dst_ptr, ulong dst_offset, + int ne00, int ne01, int ne02, int ne03, + ulong nb01, ulong nb02, ulong nb03, + int ne10, int ne11, int ne12, int ne13, + ulong nb11, ulong nb12, ulong nb13, + int ne20, int ne21, int ne22, int ne23, + ulong nb21, ulong nb22, ulong nb23, + ulong nbd1, ulong nbd2, ulong nbd3, + float eps, + local float2 * sums +) { + const int i03 = get_group_id(2); + const int i02 = get_group_id(1); + const int i01 = get_group_id(0); + + global float4 * x = (global float4 *)(src0_ptr + src0_offset + i01*nb01 + i02*nb02 + i03*nb03); + global float4 * w = (global float4 *)(src1_ptr + src1_offset + (i01%ne11)*nb11 + (i02%ne12)*nb12 + (i03%ne13)*nb13); + global float4 * b = (global float4 *)(src2_ptr + src2_offset + (i01%ne21)*nb21 + (i02%ne22)*nb22 + (i03%ne23)*nb23); + global float4 * y = (global float4 *)(dst_ptr + dst_offset + i01*nbd1 + i02*nbd2 + i03*nbd3); + + float p_sum = 0.0f; + float p_sum_sq = 0.0f; + + const int n_chunks = ne00 / 4; + for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) { + float4 val = x[i00]; + p_sum += val.x + val.y + val.z + val.w; + p_sum_sq += dot(val, val); + } + + p_sum = sub_group_reduce_add(p_sum); + p_sum_sq = sub_group_reduce_add(p_sum_sq); + + if (get_sub_group_local_id() == 0) { + sums[get_sub_group_id()] = (float2)(p_sum, p_sum_sq); + } + barrier(CLK_LOCAL_MEM_FENCE); + + if (get_local_id(0) == 0) { + float sum = 0.0f; + float sum_sq = 0.0f; + for (uint i = 0; i < get_num_sub_groups(); ++i) { + float2 s = sums[i]; + sum += s.x; + sum_sq += s.y; + } + + const float inv_ne00 = 1.0f / (float)ne00; + const float mean = sum * inv_ne00; + const float variance = mad(-mean, mean, sum_sq * inv_ne00); + + sums[0] = (float2)(mean, rsqrt(variance + eps)); + } + barrier(CLK_LOCAL_MEM_FENCE); + + const float2 mean_scale = sums[0]; + const float mean = mean_scale.x; + const float scale = mean_scale.y; + const float neg_mean_scale = -mean * scale; + + for (int i00 = get_local_id(0); i00 < n_chunks; i00 += get_local_size(0)) { + const int w_idx = ne10 > 1 ? i00 : 0; + const int b_idx = ne20 > 1 ? i00 : 0; + const float4 norm_x = mad(x[i00], (float4)scale, (float4)neg_mean_scale); + y[i00] = mad(norm_x, w[w_idx], b[b_idx]); + } +} diff --git a/ggml/src/ggml-opencl/kernels/tsembd.cl b/ggml/src/ggml-opencl/kernels/tsembd.cl index 4b1107f70ba..21444bd9582 100644 --- a/ggml/src/ggml-opencl/kernels/tsembd.cl +++ b/ggml/src/ggml-opencl/kernels/tsembd.cl @@ -26,8 +26,8 @@ kernel void kernel_timestep_embedding( local_half_dim = logical_dim / 2; local_embed_data_ptr = (global float *)((global char *)local_dst_output_base_ptr + local_i * dst_nb1_bytes); - if (logical_dim % 2 != 0 && local_j == ((logical_dim + 1) / 2)) { - local_embed_data_ptr[logical_dim] = 0.0f; + if (logical_dim % 2 != 0 && local_j == local_half_dim) { + local_embed_data_ptr[2 * local_half_dim] = 0.0f; } if (local_j >= local_half_dim) { diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index e84ff93efc3..dde1a5945ae 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -795,6 +795,7 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .graph_compute = */ ggml_backend_rpc_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, }; ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index 741630dba34..e0a1de0f322 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -225,9 +225,9 @@ struct bin_bcast_sycl { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for( - stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * sycl::range<3>(1, 1, block_size), + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * + sycl::range<3>(1, 1, block_size), sycl::range<3>(1, 1, block_size)), [=](sycl::nd_item<3> item_ct1) { k_bin_bcast_unravel( @@ -246,8 +246,9 @@ struct bin_bcast_sycl { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for( - stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, s03, s11, s12, s13, @@ -302,6 +303,10 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); } +inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); +} + inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); @@ -327,6 +332,11 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_sub(ctx, dst); } +void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_count_equal(ctx, dst); +} + void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); ggml_sycl_op_mul(ctx, dst); diff --git a/ggml/src/ggml-sycl/binbcast.hpp b/ggml/src/ggml-sycl/binbcast.hpp index 9cce0f053a5..34c4064f528 100644 --- a/ggml/src/ggml-sycl/binbcast.hpp +++ b/ggml/src/ggml-sycl/binbcast.hpp @@ -16,6 +16,12 @@ static __dpct_inline__ float op_sub(const float a, const float b) { return a - b; } +static __dpct_inline__ float op_count_equal(const float a, const float b) { + return (a == b) ? 1.0f : 0.0f; +} + +void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + static __dpct_inline__ float op_mul(const float a, const float b) { return a * b; } diff --git a/ggml/src/ggml-sycl/concat.cpp b/ggml/src/ggml-sycl/concat.cpp index 3501484a146..c7683650483 100644 --- a/ggml/src/ggml-sycl/concat.cpp +++ b/ggml/src/ggml-sycl/concat.cpp @@ -89,24 +89,33 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst, sycl::range<3> gridDim(ne2, ne1, num_blocks); switch (dim) { case 0: - sycl_parallel_for(stream, - sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1); }); - break; + stream->parallel_for( + sycl::nd_range<3>(gridDim * + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1); + }); + break; case 1: - sycl_parallel_for(stream, - sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); }); - break; + stream->parallel_for( + sycl::nd_range<3>(gridDim * + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); + }); + break; // dim >=2 will be dispatched to the default path default: - sycl_parallel_for(stream, - sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1); }); - break; + stream->parallel_for( + sycl::nd_range<3>(gridDim * + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1); + }); + break; } } @@ -120,7 +129,7 @@ static void concat_f32_sycl_non_cont( int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2, uint64_t nb3, int32_t dim) { sycl::range<3> gridDim(ne3, ne2, ne1); - sycl_parallel_for(stream, sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { + stream->parallel_for(sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { int64_t i3 = item_ct1.get_group(0); int64_t i2 = item_ct1.get_group(1); int64_t i1 = item_ct1.get_group(2); diff --git a/ggml/src/ggml-sycl/conv.cpp b/ggml/src/ggml-sycl/conv.cpp index c2f991e8d64..475bd34a25d 100644 --- a/ggml/src/ggml-sycl/conv.cpp +++ b/ggml/src/ggml-sycl/conv.cpp @@ -59,10 +59,16 @@ static void conv_transpose_1d_f32_f32_sycl( const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE; const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE); const sycl::range<3> block_nums(1, 1, num_blocks); - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - conv_transpose_1d_kernel(s0, output_size, src0_ne0, src0_ne1, src0_ne2, src1_ne0, dst_ne0, src0, src1, dst, - item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>( + block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + conv_transpose_1d_kernel( + s0, output_size, + src0_ne0, src0_ne1, src0_ne2, + src1_ne0, dst_ne0, + src0, src1, dst, item_ct1); + }); } void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 0ef567122dd..96d2583b13b 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -33,11 +33,14 @@ static void dequantize_block_sycl(const void *__restrict__ vx, { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for( - stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block(vx, y, k, item_ct1); }); + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block(vx, y, k, item_ct1); + }); } } @@ -50,18 +53,24 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for( - stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 64), + sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q2_K(vx, y, item_ct1); + }); } #else { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for( - stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q2_K(vx, y, item_ct1); + }); } #endif @@ -76,18 +85,24 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for( - stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 64), + sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q3_K(vx, y, item_ct1); + }); } #else { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for( - stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q3_K(vx, y, item_ct1); + }); } #endif } @@ -101,9 +116,12 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for( - stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_0(vx, y, nb32, item_ct1); }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q4_0(vx, y, nb32, item_ct1); + }); } } @@ -117,12 +135,13 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int int constexpr WARP_K = WARP_SIZE * QK4_0; const int n_warp = (k + WARP_K - 1) / WARP_K; GGML_ASSERT(k % 2 == 0); - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * sycl::range<3>(1, 1, WARP_SIZE), - sycl::range<3>(1, 1, WARP_SIZE)), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - dequantize_block_q4_0_reorder(vx, y, k, item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * + sycl::range<3>(1, 1, WARP_SIZE), + sycl::range<3>(1, 1, WARP_SIZE)), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{ + dequantize_block_q4_0_reorder(vx, y, k, item_ct1); + }); + } template @@ -134,9 +153,12 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for( - stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_1(vx, y, nb32, item_ct1); }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q4_1(vx, y, nb32, item_ct1); + }); } } @@ -149,13 +171,14 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor scale_local_acc(sycl::range<1>(12), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { - dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1); - }); + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1); + }); }); } } @@ -168,13 +191,13 @@ static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const i dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler & cgh) { sycl::local_accessor scale_local_acc(sycl::range<1>(12), cgh); - sycl_parallel_for<1>(cgh, sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)), - [=](sycl::nd_item<1> item_ct1) { - dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb); - }); + cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)), + [=](sycl::nd_item<1> item_ct1) { + dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb); + }); }); } @@ -187,18 +210,24 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for( - stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 64), + sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q5_K(vx, y, item_ct1); + }); } #else { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for( - stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q5_K(vx, y, item_ct1); + }); } #endif @@ -213,18 +242,24 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for( - stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 64), + sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q6_K(vx, y, item_ct1); + }); } #else { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for( - stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q6_K(vx, y, item_ct1); + }); } #endif @@ -236,9 +271,9 @@ static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const i dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); }); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); }); } template @@ -249,10 +284,15 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for( - cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_s(vx, y, item_ct1, iq1s_grid_gpu); }); + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq1_s( + vx, y, item_ct1, iq1s_grid_gpu + ); + }); }); } } @@ -265,10 +305,15 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for( - cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_m(vx, y, item_ct1, iq1s_grid_gpu); }); + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq1_m( + vx, y, item_ct1, iq1s_grid_gpu + ); + }); }); } } @@ -281,12 +326,15 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for( - cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { - dequantize_block_iq2_xxs(vx, y, item_ct1, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs); - }); + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq2_xxs( + vx, y, item_ct1, iq2xxs_grid, + ksigns_iq2xs, kmask_iq2xs); + }); }); } } @@ -299,12 +347,15 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for( - cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { - dequantize_block_iq2_xs(vx, y, item_ct1, iq2xs_grid, ksigns_iq2xs, kmask_iq2xs); - }); + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq2_xs( + vx, y, item_ct1, iq2xs_grid, + ksigns_iq2xs, kmask_iq2xs); + }); }); } } @@ -317,10 +368,13 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for( - cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq2_s(vx, y, item_ct1); }); + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq2_s(vx, y, item_ct1); + }); }); } } @@ -334,12 +388,15 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for( - cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { - dequantize_block_iq3_xxs(vx, y, item_ct1, iq3xxs_grid, ksigns_iq2xs, kmask_iq2xs); - }); + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq3_xxs( + vx, y, item_ct1, iq3xxs_grid, + ksigns_iq2xs, kmask_iq2xs); + }); }); } } @@ -352,10 +409,14 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for( - cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq3_s(vx, y, item_ct1, kmask_iq2xs, iq3s_grid); }); + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq3_s( + vx, y, item_ct1, kmask_iq2xs, iq3s_grid); + }); }); } } @@ -371,11 +432,14 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for( - cgh, - sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_xs(vx, y, item_ct1); }); + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq4_xs(vx, y, item_ct1); + }); }); } #endif @@ -389,11 +453,14 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for( - cgh, - sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_nl(vx, y, item_ct1); }); + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq4_nl(vx, y, item_ct1); + }); }); } } diff --git a/ggml/src/ggml-sycl/cpy.cpp b/ggml/src/ggml-sycl/cpy.cpp index 3d321b58ac6..1ec99b0a5d1 100644 --- a/ggml/src/ggml-sycl/cpy.cpp +++ b/ggml/src/ggml-sycl/cpy.cpp @@ -201,8 +201,7 @@ static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, co { dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - sycl_parallel_for( - stream, + stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { @@ -220,8 +219,7 @@ static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, co { dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - sycl_parallel_for( - stream, + stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { @@ -239,8 +237,7 @@ static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, co { dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - sycl_parallel_for( - stream, + stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { @@ -256,11 +253,11 @@ static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, c const int nb12, const int nb13, queue_ptr stream) { GGML_ASSERT(ne % QK8_0 == 0); const int num_blocks = ne / QK8_0; - sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); } static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, @@ -268,11 +265,11 @@ static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, c const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, queue_ptr stream) { const int num_blocks = ne; - sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), - [=](sycl::nd_item<3> item_ct1) { - cpy_q_f32(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_q_f32(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); } static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, @@ -281,11 +278,11 @@ static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, c const int nb12, const int nb13, queue_ptr stream) { GGML_ASSERT(ne % QK4_0 == 0); const int num_blocks = ne / QK4_0; - sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); } static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, @@ -293,9 +290,8 @@ static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, c const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, queue_ptr stream) { const int num_blocks = ne; - sycl_parallel_for( - stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), - [=](sycl::nd_item<3> item_ct1) { + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { cpy_q_f32, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); @@ -308,11 +304,11 @@ static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, c const int nb12, const int nb13, queue_ptr stream) { GGML_ASSERT(ne % QK4_1 == 0); const int num_blocks = ne / QK4_1; - sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); } static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, @@ -320,9 +316,8 @@ static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, c const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, queue_ptr stream) { const int num_blocks = ne; - sycl_parallel_for( - stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), - [=](sycl::nd_item<3> item_ct1) { + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { cpy_q_f32, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); @@ -335,11 +330,11 @@ static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, c const int nb12, const int nb13, queue_ptr stream) { GGML_ASSERT(ne % QK5_0 == 0); const int num_blocks = ne / QK5_0; - sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); } static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, @@ -347,9 +342,8 @@ static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, c const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, queue_ptr stream) { const int num_blocks = ne; - sycl_parallel_for( - stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), - [=](sycl::nd_item<3> item_ct1) { + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { cpy_q_f32, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); @@ -362,11 +356,11 @@ static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, c const int nb12, const int nb13, queue_ptr stream) { GGML_ASSERT(ne % QK5_1 == 0); const int num_blocks = ne / QK5_1; - sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), + [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); } static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, @@ -374,9 +368,8 @@ static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, c const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, queue_ptr stream) { const int num_blocks = ne; - sycl_parallel_for( - stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), - [=](sycl::nd_item<3> item_ct1) { + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { cpy_q_f32, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); @@ -389,11 +382,11 @@ static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne, const int nb12, const int nb13, queue_ptr stream) { GGML_ASSERT(ne % QK4_NL == 0); const int num_blocks = ne / QK4_NL; - sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), - [=](sycl::nd_item<3> item_ct1) { - cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { + cpy_f32_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, + ne12, nb10, nb11, nb12, nb13, item_ct1); + }); } static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, @@ -404,8 +397,7 @@ static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, co { dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - sycl_parallel_for( - stream, + stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { @@ -424,8 +416,7 @@ static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, co // dpct::has_capability_or_fail(stream->get_device(), // {sycl::aspect::fp16}); - sycl_parallel_for( - stream, + stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { @@ -444,8 +435,7 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co // dpct::has_capability_or_fail(stream->get_device(), // {sycl::aspect::fp16}); - sycl_parallel_for( - stream, + stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { @@ -460,13 +450,11 @@ static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, queue_ptr stream) { const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, - ne12, nb10, nb11, nb12, nb13, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); } @@ -475,13 +463,11 @@ static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, queue_ptr stream) { const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, - ne12, nb10, nb11, nb12, nb13, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); } @@ -491,13 +477,11 @@ static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const const int nb12, const int nb13, queue_ptr stream) { const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, - ne12, nb10, nb11, nb12, nb13, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); } @@ -506,13 +490,10 @@ static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, queue_ptr stream) { const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, - ne12, nb10, nb11, nb12, nb13, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); } @@ -522,13 +503,10 @@ static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const const int nb12, const int nb13, queue_ptr stream) { const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, - ne12, nb10, nb11, nb12, nb13, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); } void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try { diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 70579c0c3be..4f2760110c2 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -208,10 +208,12 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, nrows, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, + nrows, item_ct1); + }); } } @@ -875,11 +877,12 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - dequantize_mul_mat_vec_reorder(vx, y, dst, ncols, - nrows, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec_reorder( + vx, y, dst, ncols, nrows, item_ct1); + }); } } @@ -897,10 +900,12 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - dequantize_mul_mat_vec(vx, y, dst, ncols, nrows, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec( + vx, y, dst, ncols, nrows, item_ct1); + }); } } @@ -916,10 +921,12 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - dequantize_mul_mat_vec(vx, y, dst, ncols, nrows, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec( + vx, y, dst, ncols, nrows, item_ct1); + }); } } @@ -935,10 +942,12 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - dequantize_mul_mat_vec(vx, y, dst, ncols, nrows, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec( + vx, y, dst, ncols, nrows, item_ct1); + }); } } @@ -954,10 +963,12 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - dequantize_mul_mat_vec(vx, y, dst, ncols, nrows, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec( + vx, y, dst, ncols, nrows, item_ct1); + }); } } @@ -973,10 +984,12 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - dequantize_mul_mat_vec(vx, y, dst, ncols, nrows, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec( + vx, y, dst, ncols, nrows, item_ct1); + }); } } @@ -989,10 +1002,11 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y, const int block_num_y = (nrows + ny - 1) / ny; const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { - dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1); + }); } static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y, @@ -1004,10 +1018,11 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y, const int block_num_y = (nrows + ny - 1) / ny; const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { - dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1); + }); } static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y, @@ -1019,10 +1034,11 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y, const int block_num_y = (nrows + ny - 1) / ny; const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { - dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1); + }); } static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y, @@ -1031,10 +1047,11 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE); - sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { - dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1); + }); } static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y, @@ -1046,10 +1063,11 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y, const int block_num_y = (nrows + ny - 1) / ny; const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { - dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1); + }); } void ggml_sycl_op_dequantize_mul_mat_vec( diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index 27c72786078..d538965b096 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -13,10 +13,10 @@ #ifndef GGML_SYCL_DPCT_HELPER_HPP #define GGML_SYCL_DPCT_HELPER_HPP -#include #include #include #include +#include #ifdef GGML_SYCL_USE_INTEL_ONEMKL #include @@ -118,36 +118,6 @@ inline auto get_onemath_backend(sycl::queue& queue) #endif } -#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS - namespace syclex = sycl::ext::oneapi::experimental; -#endif - -template -__dpct_inline__ void sycl_parallel_for(sycl::handler & cgh, sycl::nd_range nd_range, Func && func) { -#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS - syclex::nd_launch(cgh, nd_range, func); -#else - cgh.parallel_for(nd_range, func); -#endif -} - -template -__dpct_inline__ void sycl_parallel_for(sycl::queue * q, sycl::nd_range nd_range, Func && func) { -#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS - syclex::nd_launch(*q, nd_range, func); -#else - q->parallel_for(nd_range, func); -#endif -} - -template __dpct_inline__ void sycl_launch(sycl::queue * stream, Func && func) { -#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS - syclex::submit(*stream, func); -#else - stream->submit(func); -#endif -} - namespace dpct { typedef sycl::queue *queue_ptr; diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 0363b06a3ec..c2da2fb48ad 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -407,7 +407,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst, const int ne12, const int nb1, const int nb2, const int offset, queue_ptr stream) { int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ACC_BLOCK_SIZE), sycl::range<1>(SYCL_ACC_BLOCK_SIZE)), @@ -425,8 +425,8 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, int dst_size = ne10 * ne11 * ne12 * ne13; int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE); sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); - sycl_parallel_for<1>( - stream, sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + stream->parallel_for( + sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); }); } @@ -437,7 +437,7 @@ static void pad_sycl(const T *x, T *dst, const int ne00, const int ne1, const int ne2, queue_ptr stream) { int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE); sycl::range<3> gridDim(ne2, ne1, num_blocks); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); }); @@ -639,7 +639,7 @@ static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, 256); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), sycl::range<1>(256)), [=](sycl::nd_item<1> item_ct1) { @@ -652,7 +652,7 @@ static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, 256); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), sycl::range<1>(256)), [=](sycl::nd_item<1> item_ct1) { @@ -665,7 +665,7 @@ static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, 256); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), sycl::range<1>(256)), [=](sycl::nd_item<1> item_ct1) { @@ -678,7 +678,7 @@ static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tenso ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -691,7 +691,7 @@ static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tenso ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -704,7 +704,7 @@ static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -717,7 +717,7 @@ static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_t ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -730,7 +730,7 @@ static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tenso ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE), sycl::range<1>(SYCL_TANH_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -743,7 +743,7 @@ static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tenso ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -756,7 +756,7 @@ static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggm ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE), sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -769,7 +769,7 @@ static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE), sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -782,7 +782,7 @@ static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE), sycl::range<1>(SYCL_EXP_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -795,7 +795,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); // Using EXP block size - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE), sycl::range<1>(SYCL_EXP_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -808,7 +808,7 @@ static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE), sycl::range<1>(SYCL_NEG_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -821,7 +821,7 @@ static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tenso ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE), sycl::range<1>(SYCL_NEG_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -834,7 +834,7 @@ static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_te ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE), sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -847,7 +847,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_SQRT_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE), sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -860,7 +860,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -873,7 +873,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); // Using SIN block size - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -888,7 +888,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float slope) { const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -901,7 +901,7 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE), sycl::range<1>(SYCL_SQR_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -935,7 +935,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) { const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE); - sycl_parallel_for(stream, + stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE), sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -967,7 +967,7 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); - sycl_parallel_for(main_stream, + main_stream->parallel_for( sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); @@ -978,7 +978,7 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu - sycl_parallel_for(main_stream, + main_stream->parallel_for( sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); @@ -989,7 +989,7 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu - sycl_parallel_for(main_stream, + main_stream->parallel_for( sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); @@ -1000,7 +1000,7 @@ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); - sycl_parallel_for(main_stream, + main_stream->parallel_for( sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); @@ -1011,7 +1011,7 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); - sycl_parallel_for(main_stream, + main_stream->parallel_for( sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp index 9c76ffeb950..03f8dd90748 100644 --- a/ggml/src/ggml-sycl/getrows.cpp +++ b/ggml/src/ggml-sycl/getrows.cpp @@ -118,10 +118,12 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr GGML_ASSERT(ne00 % 2 == 0); - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - k_get_rows(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12, - item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_get_rows( + src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, + s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); + }); GGML_UNUSED(dst); GGML_UNUSED(ctx); @@ -154,8 +156,9 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_parallel_for( - stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); }); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index a0a650e92e4..78853eb6767 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1746,12 +1746,13 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, const size_t shared_mem = ncols_pad * sizeof(int); if (order == GGML_SORT_ORDER_ASC) { - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor dpct_local_acc_ct1( sycl::range<1>(shared_mem), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { k_argsort_f32_i32( x, dst, ncols, ncols_pad, item_ct1, dpct_local_acc_ct1.get_multi_ptr() @@ -1759,12 +1760,13 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, }); }); } else if (order == GGML_SORT_ORDER_DESC) { - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor dpct_local_acc_ct1( sycl::range<1>(shared_mem), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { k_argsort_f32_i32( x, dst, ncols, ncols_pad, item_ct1, dpct_local_acc_ct1.get_multi_ptr() @@ -1782,47 +1784,50 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols, const sycl::range<3> block_nums(1, nrows, 1); const size_t shared_mem = 256 * sizeof(float); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor shared_data( sycl::range<1>(shared_mem/sizeof(float)), cgh); sycl::local_accessor shared_indices( sycl::range<1>(shared_mem/sizeof(float)), cgh); - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - const int tid = item_ct1.get_local_id(2); - const int row = item_ct1.get_global_id(1); - - float max_val = -INFINITY; - int max_idx = -1; - - for (int col = tid; col < ncols; col += 256) { - float val = x[row * ncols + col]; - if (val > max_val) { - max_val = val; - max_idx = col; + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + const int tid = item_ct1.get_local_id(2); + const int row = item_ct1.get_global_id(1); + + float max_val = -INFINITY; + int max_idx = -1; + + for (int col = tid; col < ncols; col += 256) { + float val = x[row * ncols + col]; + if (val > max_val) { + max_val = val; + max_idx = col; + } } - } - shared_data[tid] = max_val; - shared_indices[tid] = max_idx; - item_ct1.barrier(sycl::access::fence_space::local_space); + shared_data[tid] = max_val; + shared_indices[tid] = max_idx; + item_ct1.barrier(sycl::access::fence_space::local_space); - for (int stride = 256 / 2; stride > 0; stride >>= 1) { - if (tid < stride) { - float val1 = shared_data[tid]; - float val2 = shared_data[tid + stride]; - if (val2 > val1) { - shared_data[tid] = val2; - shared_indices[tid] = shared_indices[tid + stride]; + for (int stride = 256/2; stride > 0; stride >>= 1) { + if (tid < stride) { + float val1 = shared_data[tid]; + float val2 = shared_data[tid + stride]; + if (val2 > val1) { + shared_data[tid] = val2; + shared_indices[tid] = shared_indices[tid + stride]; + } } + item_ct1.barrier(sycl::access::fence_space::local_space); } - item_ct1.barrier(sycl::access::fence_space::local_space); - } - if (tid == 0) { - dst[row] = shared_indices[0]; - } - }); + + if (tid == 0) { + dst[row] = shared_indices[0]; + } + }); }); } static void diag_mask_inf_f32_sycl(const float *x, float *dst, @@ -2895,7 +2900,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons void ** ptrs_dst_get = ptrs_dst.get(); size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half); size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half); - sycl_parallel_for(cgh, sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02, nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1); }); @@ -3403,7 +3408,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, { sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size)); sycl::range<3> grid_dims(1, n_ids, ids->ne[1]); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor src1_row_acc(cgh); char *__restrict src1_contiguous_get = @@ -3415,8 +3420,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, size_t ids_nb_ct6 = ids->nb[1]; size_t ids_nb_ct7 = ids->nb[0]; - sycl_parallel_for( - cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { k_copy_src1_to_contiguous( src1_original, src1_contiguous_get, dev_cur_src1_row_get, @@ -3447,14 +3453,15 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, { sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size)); sycl::range<3> grid_dims(1, 1, num_src1_rows); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { const char *__restrict dst_contiguous_get = dst_contiguous.get(); const mmid_row_mapping *__restrict dev_row_mapping_get = dev_row_mapping.get(); - sycl_parallel_for( - cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { k_copy_dst_from_contiguous(dst_original, dst_contiguous_get, dev_row_mapping_get, @@ -3570,6 +3577,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_SUB: ggml_sycl_sub(ctx, dst); break; + case GGML_OP_COUNT_EQUAL: + ggml_sycl_count_equal(ctx, dst); + break; case GGML_OP_ACC: ggml_sycl_acc(ctx, dst); break; @@ -4063,6 +4073,7 @@ static ggml_backend_i ggml_backend_sycl_interface = { /* .graph_compute = */ ggml_backend_sycl_graph_compute, /* .event_record = */ ggml_backend_sycl_event_record, /* .event_wait = */ ggml_backend_sycl_event_wait, + /* .graph_optimize = */ NULL, }; static ggml_guid_t ggml_backend_sycl_guid() { @@ -4348,6 +4359,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ADD: case GGML_OP_ADD1: case GGML_OP_SUB: + case GGML_OP_COUNT_EQUAL: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_REPEAT: @@ -4364,11 +4376,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type); #endif case GGML_OP_NORM: - case GGML_OP_RMS_NORM: return true; case GGML_OP_L2_NORM: case GGML_OP_GROUP_NORM: return ggml_is_contiguous(op->src[0]); + case GGML_OP_RMS_NORM: + return ((op->src[0]->ne[0] % WARP_SIZE) == 0); case GGML_OP_SCALE: return true; case GGML_OP_CONT: @@ -4391,12 +4404,16 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return true; case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; - case GGML_OP_POOL_2D: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: + return ggml_is_contiguous(op->src[0]); + case GGML_OP_POOL_2D: case GGML_OP_ACC: + return true; case GGML_OP_PAD: + return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && + (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); case GGML_OP_LEAKY_RELU: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_RWKV_WKV6: diff --git a/ggml/src/ggml-sycl/gla.cpp b/ggml/src/ggml-sycl/gla.cpp index b40cbf1f14f..879184fdd31 100644 --- a/ggml/src/ggml-sycl/gla.cpp +++ b/ggml/src/ggml-sycl/gla.cpp @@ -11,13 +11,13 @@ static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B, const u_int n_seq_tokens = T / B; sycl::range<1> block_dims((C / H)); sycl::range<1> grid_dims((B * H)); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler & cgh) { /* local memory accessors*/ auto _k = sycl::local_accessor(sycl::range<1>(head_size), cgh); auto _r = sycl::local_accessor(sycl::range<1>(head_size), cgh); auto _td = sycl::local_accessor(sycl::range<1>(head_size), cgh); - sycl_parallel_for<1>(cgh, sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) { + cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) { u_int tid = item.get_local_id(0); u_int bid = item.get_group(0); diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp index 7adcb3d9d9c..6d75d34d83f 100644 --- a/ggml/src/ggml-sycl/im2col.cpp +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -70,7 +70,7 @@ static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t I const int64_t CHW = IC * KH * KW; - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) { + stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) { im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1, p0, p1, d0, d1, item_ct1); }); diff --git a/ggml/src/ggml-sycl/mmq.cpp b/ggml/src/ggml-sycl/mmq.cpp index c72fcd38ebe..ffb272aa283 100644 --- a/ggml/src/ggml-sycl/mmq.cpp +++ b/ggml/src/ggml-sycl/mmq.cpp @@ -1818,7 +1818,7 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_qs_q4_0_acc_ct1( sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_d_q4_0_acc_ct1( @@ -1829,8 +1829,9 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q4_0( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -1852,7 +1853,7 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_qs_q4_0_acc_ct1( sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_d_q4_0_acc_ct1( @@ -1863,8 +1864,9 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q4_0( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -1931,7 +1933,7 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_qs_q4_1_acc_ct1( sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh); sycl::local_accessor tile_x_dm_q4_1_acc_ct1( @@ -1942,8 +1944,9 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q4_1( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -1965,7 +1968,7 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_qs_q4_1_acc_ct1( sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh); sycl::local_accessor tile_x_dm_q4_1_acc_ct1( @@ -1976,8 +1979,9 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q4_1( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2044,7 +2048,7 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_ql_q5_0_acc_ct1( sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_d_q5_0_acc_ct1( @@ -2055,8 +2059,9 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q5_0( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2078,7 +2083,7 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_ql_q5_0_acc_ct1( sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_d_q5_0_acc_ct1( @@ -2089,8 +2094,9 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q5_0( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2157,7 +2163,7 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_ql_q5_1_acc_ct1( sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_dm_q5_1_acc_ct1( @@ -2168,8 +2174,9 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q5_1( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2191,7 +2198,7 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_ql_q5_1_acc_ct1( sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_dm_q5_1_acc_ct1( @@ -2202,8 +2209,9 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q5_1( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2270,7 +2278,7 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_qs_q8_0_acc_ct1( sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_d_q8_0_acc_ct1( @@ -2281,8 +2289,9 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q8_0( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2304,7 +2313,7 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_qs_q8_0_acc_ct1( sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_d_q8_0_acc_ct1( @@ -2315,8 +2324,9 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q8_0( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2383,7 +2393,7 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_ql_q2_K_acc_ct1( sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_dm_q2_K_acc_ct1( @@ -2396,8 +2406,9 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q2_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2420,7 +2431,7 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_ql_q2_K_acc_ct1( sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_dm_q2_K_acc_ct1( @@ -2433,8 +2444,9 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q2_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2504,7 +2516,7 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_ql_q3_K_acc_ct1( sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_dm_q3_K_acc_ct1( @@ -2519,8 +2531,9 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q3_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2544,7 +2557,7 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_ql_q3_K_acc_ct1( sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_dm_q3_K_acc_ct1( @@ -2559,8 +2572,9 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q3_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2630,7 +2644,7 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_ql_q4_K_acc_ct1( sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_dm_q4_K_acc_ct1( @@ -2643,8 +2657,9 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q4_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2667,7 +2682,7 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_ql_q4_K_acc_ct1( sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_dm_q4_K_acc_ct1( @@ -2680,8 +2695,9 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q4_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2749,7 +2765,7 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_ql_q5_K_acc_ct1( sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_dm_q5_K_acc_ct1( @@ -2762,8 +2778,9 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q5_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2786,7 +2803,7 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_ql_q5_K_acc_ct1( sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_dm_q5_K_acc_ct1( @@ -2799,8 +2816,9 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q5_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2868,7 +2886,7 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_ql_acc_ct1( sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_dm_acc_ct1( @@ -2881,8 +2899,9 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q6_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, @@ -2905,7 +2924,7 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor tile_x_ql_acc_ct1( sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::local_accessor tile_x_dm_acc_ct1( @@ -2918,8 +2937,9 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy, sycl::local_accessor tile_y_ds_acc_ct1( sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { mul_mat_q6_K( vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, item_ct1, diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index c21929d51e9..5b7f0640749 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -544,12 +544,12 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size), - [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, - nd_item); - }); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, + nd_item); + }); }); } @@ -561,12 +561,12 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -580,12 +580,17 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -599,12 +604,17 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -618,12 +628,17 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -637,12 +652,17 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -656,12 +676,17 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -675,12 +700,17 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -694,12 +724,17 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -715,12 +750,12 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size), - [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, - nd_item); - }); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, + nrows, nd_item); + }); }); } @@ -734,12 +769,17 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -754,12 +794,12 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size), - [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, - nd_item); - }); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, + nd_item); + }); }); } static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, @@ -771,12 +811,17 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -791,12 +836,14 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q_iq2_xxs_q8_1(vx, vy, dst, ncols, - nrows, item_ct1); - }); + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq2_xxs_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -810,12 +857,14 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q_iq2_xs_q8_1(vx, vy, dst, ncols, - nrows, item_ct1); - }); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq2_xs_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -829,12 +878,15 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q_iq2_s_q8_1(vx, vy, dst, ncols, nrows, - item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq2_s_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -848,12 +900,15 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q_iq3_xxs_q8_1(vx, vy, dst, ncols, - nrows, item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq3_xxs_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -867,12 +922,15 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q_iq3_s_q8_1(vx, vy, dst, ncols, nrows, - item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq3_s_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -886,12 +944,15 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q_iq1_s_q8_1(vx, vy, dst, ncols, nrows, - item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq1_s_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -905,12 +966,14 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q_iq1_m_q8_1(vx, vy, dst, ncols, nrows, - item_ct1); - }); + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq1_m_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -924,12 +987,15 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q_iq4_nl_q8_1(vx, vy, dst, ncols, nrows, - item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq4_nl_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -943,12 +1009,15 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy, const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); { - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q_iq4_xs_q8_1(vx, vy, dst, ncols, - nrows, item_ct1); - }); + + stream->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq4_xs_q8_1( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index 79d846b41a1..4ec1416849c 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -254,13 +254,14 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, - nullptr, WARP_SIZE); - }); - }); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(global_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE); + }); + }); } else { const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; @@ -271,15 +272,16 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. */ - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler& cgh) { sycl::local_accessor s_sum_acc_ct1( sycl::range<1>(work_group_size / WARP_SIZE), cgh); - sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, - get_pointer(s_sum_acc_ct1), work_group_size); - }); - }); + cgh.parallel_for( + sycl::nd_range<3>(global_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); + }); + }); } } @@ -288,14 +290,18 @@ static void group_norm_f32_sycl(const float* x, float* dst, const int ne_elements, queue_ptr stream, int device) { if (group_size < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler& cgh) { const float eps_ct4 = eps; - sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1, nullptr, - WARP_SIZE); - }); - }); + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + group_norm_f32( + x, dst, group_size, ne_elements, eps_ct4, item_ct1, + nullptr, WARP_SIZE); + }); + }); } else { const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; @@ -307,18 +313,22 @@ static void group_norm_f32_sycl(const float* x, float* dst, info::device::max_work_group_size. Adjust the work-group size if needed. */ - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler& cgh) { sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), cgh); const float eps_ct4 = eps; - sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1, - get_pointer(s_sum_acc_ct1), work_group_size); - }); - }); + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + group_norm_f32(x, dst, group_size, ne_elements, + eps_ct4, item_ct1, + get_pointer(s_sum_acc_ct1), work_group_size); + }); + }); } } @@ -330,13 +340,14 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const const sycl::range<3> global_dims(nsamples, nchannels, nrows); if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, - nullptr, WARP_SIZE); - }); - }); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(global_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE); + }); + }); } else { const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; @@ -347,15 +358,16 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. */ - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler& cgh) { sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), cgh); - sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, - get_pointer(s_sum_acc_ct1), work_group_size); - }); - }); + cgh.parallel_for( + sycl::nd_range<3>(global_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); + }); + }); } } @@ -366,12 +378,16 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols, // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); - sycl_launch(stream, [&](sycl::handler & cgh) { - sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - l2_norm_f32(x, dst, ncols, eps, item_ct1, nullptr, WARP_SIZE); - }); - }); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + l2_norm_f32(x, dst, ncols, eps, item_ct1, + nullptr, WARP_SIZE); + }); + }); } else { const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; @@ -382,15 +398,18 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols, the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. */ - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler& cgh) { sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), cgh); - sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - l2_norm_f32(x, dst, ncols, eps, item_ct1, get_pointer(s_sum_acc_ct1), - work_group_size); - }); - }); + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + l2_norm_f32(x, dst, ncols, eps, item_ct1, + get_pointer(s_sum_acc_ct1), work_group_size); + }); + }); } } diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index 1b60226dcd5..a3ab703d1f0 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -232,22 +232,20 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. */ - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_norm(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, - attn_factor, corr_dims, theta_scale, freq_factors, item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_norm(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); } else { /* DPCT1049:41: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. */ - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_norm(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, - attn_factor, corr_dims, theta_scale, freq_factors, item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_norm(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); } } @@ -266,17 +264,15 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); if (freq_factors == nullptr) { - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, - attn_factor, corr_dims, theta_scale, freq_factors, item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); } else { - sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, - attn_factor, corr_dims, theta_scale, freq_factors, item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); } } @@ -299,12 +295,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, } // launch kernel if (freq_factors == nullptr) { - sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { rope_multi(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, item_ct1); }); } else { - sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { rope_multi(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, item_ct1); }); @@ -334,12 +330,12 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, } // launch kernel if (freq_factors == nullptr) { - sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { rope_vision(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, item_ct1); }); } else { - sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { rope_vision(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, item_ct1); }); diff --git a/ggml/src/ggml-sycl/set_rows.cpp b/ggml/src/ggml-sycl/set_rows.cpp index 7a8e1410b70..fbe15ffdd77 100644 --- a/ggml/src/ggml-sycl/set_rows.cpp +++ b/ggml/src/ggml-sycl/set_rows.cpp @@ -48,7 +48,7 @@ static void set_rows_sycl_q(const char * __restrict__ src0_d, constexpr int block_size = 256; const int64_t grid_size = ceil_div(total_blocks, block_size); - sycl_parallel_for(stream, sycl::nd_range<1>(grid_size * block_size, block_size), [=](sycl::nd_item<1> item_ct1) { + stream->parallel_for(sycl::nd_range<1>(grid_size * block_size, block_size), [=](sycl::nd_item<1> item_ct1) { const int64_t i = item_ct1.get_global_linear_id(); if (i >= total_blocks) { return; @@ -129,8 +129,7 @@ static void set_rows_sycl( constexpr int block_size = 64; const int64_t grid_size = ceil_div(total_elements, block_size); - sycl_parallel_for( - stream, + stream->parallel_for( sycl::nd_range<1>(grid_size * block_size, block_size), [=](sycl::nd_item<1> item_ct1) { k_set_rows( diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index 7b60c292e0c..52fcf4b3dbd 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -127,11 +127,11 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, const size_t n_local_scratch, queue_ptr stream) { - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler &cgh) { sycl::local_accessor local_buf_acc(n_local_scratch, cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { soft_max_f32(x, mask, dst, ncols_par, nrows_y, scale, max_bias, m0, diff --git a/ggml/src/ggml-sycl/tsembd.cpp b/ggml/src/ggml-sycl/tsembd.cpp index 721c8fa6fa2..f2003794d3f 100644 --- a/ggml/src/ggml-sycl/tsembd.cpp +++ b/ggml/src/ggml-sycl/tsembd.cpp @@ -21,11 +21,12 @@ static void timestep_embedding_f32( int j = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2); float * embed_data = (float *)((char *)dst + i*nb1); - if (dim % 2 != 0 && j == ((dim + 1) / 2)) { - embed_data[dim] = 0.f; + int half = dim / 2; + + if (dim % 2 != 0 && j == half) { + embed_data[2 * half] = 0.f; } - int half = dim / 2; if (j >= half) { return; } @@ -45,9 +46,14 @@ static void timestep_embedding_f32_sycl( int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE; sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE); sycl::range<3> gridDim(1, ne00, num_blocks); - sycl_parallel_for(stream, sycl::nd_range<3>(gridDim * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - timestep_embedding_f32(x, dst, nb1, dim, max_period, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>( + gridDim * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + timestep_embedding_f32( + x, dst, nb1, dim, max_period, item_ct1 + ); + }); } void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-sycl/wkv.cpp b/ggml/src/ggml-sycl/wkv.cpp index 3ed5bbf355a..c10e2f7645e 100644 --- a/ggml/src/ggml-sycl/wkv.cpp +++ b/ggml/src/ggml-sycl/wkv.cpp @@ -207,11 +207,12 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { // Submit kernel if (C / H == WKV_BLOCK_SIZE) { - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler& cgh) { sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { rwkv_wkv6_f32_kernel( B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() @@ -219,11 +220,12 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { }); }); } else { - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler& cgh) { sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { rwkv_wkv6_f32_kernel( B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() @@ -262,11 +264,12 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { // Submit kernel if (C / H == WKV_BLOCK_SIZE) { - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler& cgh) { sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { rwkv_wkv7_f32_kernel( B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() @@ -274,11 +277,12 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { }); }); } else { - sycl_launch(stream, [&](sycl::handler & cgh) { + stream->submit([&](sycl::handler& cgh) { sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); - sycl_parallel_for( - cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { rwkv_wkv7_f32_kernel( B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7ef938066b1..5c941e72135 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5,8 +5,14 @@ #include "ggml-cpu.h" #endif +// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- +#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1 + #include +// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- +VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE + #include #include #include @@ -102,9 +108,9 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } struct ggml_backend_vk_context; -#define MAX_PARAMETER_COUNT 8 +#define MAX_PARAMETER_COUNT 12 // Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT. -#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 2) +#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3) struct vk_pipeline_struct { std::string name; @@ -115,10 +121,14 @@ struct vk_pipeline_struct { uint32_t parameter_count; std::array wg_denoms; uint32_t align; + // true if fields have been set by ggml_vk_create_pipeline + bool initialized {}; // set to true to request the pipeline is compiled after the dryrun bool needed {}; // set to true when the shader has been compiled bool compiled {}; + // number of registers used, extracted from pipeline executable properties + uint32_t register_count {}; }; typedef std::shared_ptr vk_pipeline; @@ -227,21 +237,6 @@ enum vk_device_architecture { NVIDIA_PRE_TURING, }; -// HSK x HSV -enum FaHeadSizes { - FA_HEAD_SIZE_64, - FA_HEAD_SIZE_80, - FA_HEAD_SIZE_96, - FA_HEAD_SIZE_112, - FA_HEAD_SIZE_128, - FA_HEAD_SIZE_192, - FA_HEAD_SIZE_192_128, - FA_HEAD_SIZE_256, - FA_HEAD_SIZE_576_512, - FA_HEAD_SIZE_UNSUPPORTED, - FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED, -}; - static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { vk::PhysicalDeviceProperties props = device.getProperties(); @@ -351,6 +346,35 @@ enum dmmv_wg_sizes { DMMV_WG_SIZE_COUNT, }; +enum FaCodePath { + FA_SCALAR, + FA_COOPMAT1, + FA_COOPMAT2, +}; + +struct vk_fa_pipeline_state { + vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc) + : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {} + + uint32_t HSK, HSV; + bool small_rows; + FaCodePath path; + bool aligned; + bool f32acc; + + bool operator<(const vk_fa_pipeline_state &b) const { + return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) < + std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc); + } +}; + +enum shader_reduction_mode { + SHADER_REDUCTION_MODE_SHMEM, + SHADER_REDUCTION_MODE_HYBRID, + SHADER_REDUCTION_MODE_SUBGROUP, + SHADER_REDUCTION_MODE_COUNT, +}; + static constexpr uint32_t num_argsort_pipelines = 11; static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); @@ -377,11 +401,18 @@ struct vk_device_struct { bool uma; bool prefer_host_memory; bool float_controls_rte_fp16; - bool subgroup_add; + bool subgroup_arithmetic; bool subgroup_shuffle; + bool subgroup_ballot; + bool subgroup_clustered; bool multi_add; + bool add_rms_fusion; + uint32_t partials_binding_alignment; + bool integer_dot_product; + // 0: default, 1: force mmvq, -1: disable mmvq + int32_t mmvq_mode; bool subgroup_size_control; uint32_t subgroup_min_size; @@ -406,6 +437,8 @@ struct vk_device_struct { bool coopmat2; + bool pipeline_executable_properties_support {}; + size_t idx; bool mul_mat_l[GGML_TYPE_COUNT]; @@ -439,12 +472,15 @@ struct vk_device_struct { vk_pipeline pipeline_matmul_split_k_reduce; vk_pipeline pipeline_quantize_q8_1; + vk_pipeline pipeline_quantize_q8_1_x4; vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; @@ -460,9 +496,12 @@ struct vk_device_struct { vk_pipeline pipeline_mul_norepeat[2][2][2]; vk_pipeline pipeline_div[2][2][2]; vk_pipeline pipeline_div_norepeat[2][2][2]; + vk_pipeline pipeline_add_rms[2][2][2]; + vk_pipeline pipeline_add_rms_norepeat[2][2][2]; // indexed by num_additional_fused_ops == num_adds - 1 vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS]; + vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS]; vk_pipeline pipeline_add_id_f32; @@ -477,8 +516,8 @@ struct vk_device_struct { vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; - vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16; - vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT]; @@ -486,10 +525,13 @@ struct vk_device_struct { vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; vk_pipeline pipeline_rms_norm_mul_f32; + vk_pipeline pipeline_rms_norm_partials_f32; + vk_pipeline pipeline_rms_norm_mul_partials_f32; vk_pipeline pipeline_rms_norm_back_f32; vk_pipeline pipeline_l2_norm_f32; // [src/dst 0=fp32,1=fp16] + vk_pipeline pipeline_exp[2]; vk_pipeline pipeline_gelu[2]; vk_pipeline pipeline_gelu_erf[2]; vk_pipeline pipeline_gelu_quick[2]; @@ -497,6 +539,8 @@ struct vk_device_struct { vk_pipeline pipeline_relu[2]; vk_pipeline pipeline_tanh[2]; vk_pipeline pipeline_sigmoid[2]; + vk_pipeline pipeline_hardsigmoid[2]; + vk_pipeline pipeline_hardswish[2]; vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_reglu[2]; @@ -520,6 +564,7 @@ struct vk_device_struct { vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; + vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_conv_transpose_1d_f32; vk_pipeline pipeline_pool2d_f32; @@ -529,19 +574,14 @@ struct vk_device_struct { vk_pipeline pipeline_opt_step_sgd_f32; vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT]; vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT]; - vk_pipeline pipeline_conv2d_dw_whcn_f32; - vk_pipeline pipeline_conv2d_dw_cwhn_f32; - - // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} - vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; + vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32; + vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; - vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; - - vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; + std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT]; vk_pipeline pipeline_flash_attn_split_k_reduce; - std::unordered_map pipelines; + std::vector all_pipelines; std::vector> pinned_memory; @@ -552,6 +592,8 @@ struct vk_device_struct { bool disable_fusion; bool disable_host_visible_vidmem; + bool allow_sysmem_fallback; + bool disable_graph_optimize; #ifdef GGML_VULKAN_MEMORY_DEBUG std::unique_ptr memory_logger; @@ -572,15 +614,15 @@ struct vk_device_struct { compute_queue.cmd_pool.destroy(device); transfer_queue.cmd_pool.destroy(device); - for (auto& pipeline : pipelines) { - if (pipeline.second.expired()) { + for (auto& pipeline : all_pipelines) { + if (pipeline.expired()) { continue; } - vk_pipeline pl = pipeline.second.lock(); + vk_pipeline pl = pipeline.lock(); ggml_vk_destroy_pipeline(device, pl); } - pipelines.clear(); + all_pipelines.clear(); device.destroyDescriptorSetLayout(dsl); @@ -773,6 +815,57 @@ static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_ten p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); + return p; // offsets are initialized later in ggml_vk_op +} + +struct vk_op_pad_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t misalign_offsets; + + uint32_t lp0; uint32_t rp0; + uint32_t lp1; uint32_t rp1; + uint32_t lp2; uint32_t rp2; + uint32_t lp3; uint32_t rp3; +}; + +static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) { + int64_t ne = ggml_nelements(dst); + GGML_ASSERT(ne <= (int64_t)std::numeric_limits::max()); + + vk_op_pad_push_constants p{}; + p.ne = (uint32_t)ne; + + size_t src0_tsize = ggml_type_size(src0->type); + p.ne00 = (uint32_t)src0->ne[0]; + p.ne01 = (uint32_t)src0->ne[1]; + p.ne02 = (uint32_t)src0->ne[2]; + p.ne03 = (uint32_t)src0->ne[3]; + p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize); + p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize); + p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize); + p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize); + + size_t dst_tsize = ggml_type_size(dst->type); + p.ne10 = (uint32_t)dst->ne[0]; + p.ne11 = (uint32_t)dst->ne[1]; + p.ne12 = (uint32_t)dst->ne[2]; + p.ne13 = (uint32_t)dst->ne[3]; + p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize); + p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize); + p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); + p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); + + p.lp0 = dst->op_params[0]; + p.rp0 = dst->op_params[1]; + p.lp1 = dst->op_params[2]; + p.rp1 = dst->op_params[3]; + p.lp2 = dst->op_params[4]; + p.rp2 = dst->op_params[5]; + p.lp3 = dst->op_params[6]; + p.rp3 = dst->op_params[7]; + return p; // fastdiv values and offsets are initialized later in ggml_vk_op } @@ -822,8 +915,13 @@ struct vk_op_multi_add_push_constants { uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; // strides for srcs+dst - uint32_t nb[8][4]; + uint32_t nb[MAX_PARAMETER_COUNT][4]; + + uint32_t rms_partials; }; +// update multi_add.comp if this changes +static_assert(MAX_PARAMETER_COUNT == 12); +static_assert(sizeof(vk_op_multi_add_push_constants) <= 256); struct vk_op_add_id_push_constants { uint32_t ne0; @@ -896,6 +994,37 @@ struct vk_op_im2col_push_constants { int32_t d0; int32_t d1; }; +struct vk_op_im2col_3d_push_constants { + uint32_t nb10; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t s0; + uint32_t s1; + uint32_t s2; + uint32_t p0; + uint32_t p1; + uint32_t p2; + uint32_t d0; + uint32_t d1; + uint32_t d2; + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t IC; + uint32_t KW; + uint32_t OH; + uint32_t KD_KH_KW; + uint32_t KH_KW; + uint32_t IC_KD_KH_KW; + uint32_t N_OD_OH; + uint32_t OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW; + uint32_t misalign_offsets; +}; + struct vk_op_timestep_embedding_push_constants { uint32_t nb1; uint32_t dim; @@ -1014,6 +1143,39 @@ struct vk_op_upscale_push_constants { float sf0; float sf1; float sf2; float sf3; }; +struct vk_op_sum_rows_push_constants +{ + uint32_t n_cols; + uint32_t ne01, ne02; + uint32_t nb01, nb02, nb03; + uint32_t nb11, nb12, nb13; + float weight; + uint32_t misalign_offsets; + uint32_t ne0_12mp, ne0_12L; + uint32_t ne0_1mp, ne0_1L; +}; + +static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) { + uint32_t type_size = (uint32_t)ggml_type_size(src->type); + vk_op_sum_rows_push_constants p = {}; + p.n_cols = (uint32_t)n_cols; + p.ne01 = (uint32_t)src->ne[1]; + p.ne02 = (uint32_t)src->ne[2]; + p.nb01 = (uint32_t)src->nb[1] / type_size; + p.nb02 = (uint32_t)src->nb[2] / type_size; + p.nb03 = (uint32_t)src->nb[3] / type_size; + p.nb11 = (uint32_t)dst->nb[1] / type_size; + p.nb12 = (uint32_t)dst->nb[2] / type_size; + p.nb13 = (uint32_t)dst->nb[3] / type_size; + p.weight = 1.0f; + return p; +} + +template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) { + init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L); + init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L); +} + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -1069,8 +1231,6 @@ static std::string format_size(size_t size) { return oss.str(); } -static std::mutex log_mutex; - class vk_memory_logger { public: vk_memory_logger(): total_device(0), total_host(0) {} @@ -1174,6 +1334,12 @@ class vk_perf_logger { timings[name].push_back(time); return; } + if (node->op == GGML_OP_RMS_NORM) { + std::string name = ggml_op_name(node->op); + name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")"; + timings[name].push_back(time); + return; + } timings[ggml_op_name(node->op)].push_back(time); } private: @@ -1188,10 +1354,25 @@ struct ggml_backend_vk_context { size_t semaphore_idx, event_idx; ggml_vk_garbage_collector gc; - size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k; - vk_buffer prealloc_x, prealloc_y, prealloc_split_k; + size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset; + vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials; vk::Fence fence, almost_ready_fence; bool almost_ready_fence_pending {}; + // Set before op_add and unset after op_rms_norm to indicate that the add should + // write partial sums to accumulate the square of the vector components + bool do_add_rms_partials; + + // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert. + vk_pipeline_struct * prealloc_y_last_pipeline_used {}; + const ggml_tensor * prealloc_y_last_tensor_used {}; + + // Track which nodes have been used since the last sync, and whether they were written to + std::vector unsynced_nodes_written; + std::vector unsynced_nodes_read; + // Track which prealloc buffers have pending reads that need to be synchronized. + // These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set), + // and set to true after the buffer contents are consumed. + bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync; vk_buffer buffer_pool[MAX_VK_BUFFERS]; @@ -1239,6 +1420,8 @@ struct ggml_backend_vk_buffer_context { }; #ifdef GGML_VULKAN_MEMORY_DEBUG +static std::mutex log_mutex; + void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) { std::lock_guard guard(log_mutex); vk_buffer buf = buf_ref.lock(); @@ -1283,6 +1466,7 @@ struct vk_instance_t { PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {}; std::vector device_indices; + std::vector device_supports_membudget; vk_device devices[GGML_VK_MAX_DEVICES]; }; @@ -1429,9 +1613,23 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast(duoni)); } + if (device->pipeline_executable_properties_support) { + vk::PipelineExecutableInfoKHR executableInfo; + executableInfo.pipeline = pipeline->pipeline; + + auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo); + for (auto & s : statistics) { + // "Register Count" is reported by NVIDIA drivers. + if (strcmp(s.name, "Register Count") == 0) { + VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers"); + pipeline->register_count = (uint32_t)s.value.u64; + } + } + } + { std::lock_guard guard(device->mutex); - device->pipelines.insert({ pipeline->name, pipeline }); + device->all_pipelines.push_back(pipeline); } { @@ -1735,8 +1933,8 @@ static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_pr return UINT32_MAX; } -static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { - VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")"); +static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list) { + VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")"); if (size > device->max_memory_allocation_size) { throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit"); } @@ -1763,42 +1961,34 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); - uint32_t memory_type_index = UINT32_MAX; + for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { + const auto & req_flags = *it; - memory_type_index = find_properties(&mem_props, &mem_req, req_flags); - buf->memory_property_flags = req_flags; + uint32_t memory_type_index = find_properties(&mem_props, &mem_req, req_flags); - if (memory_type_index == UINT32_MAX && fallback_flags) { - memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); - buf->memory_property_flags = fallback_flags; - } - - if (memory_type_index == UINT32_MAX) { - device->device.destroyBuffer(buf->buffer); - throw vk::OutOfDeviceMemoryError("No suitable memory type found"); - } - - try { - buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); - } catch (const vk::SystemError& e) { - if (buf->memory_property_flags != fallback_flags) { - // Try again with fallback flags - memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); - buf->memory_property_flags = fallback_flags; + if (memory_type_index == UINT32_MAX) { + continue; + } + buf->memory_property_flags = req_flags; - try { - buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); - } - catch (const vk::SystemError& e) { + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); + break; + } catch (const vk::SystemError& e) { + // loop and retry + // during last attempt throw the exception + if (it + 1 == req_flags_list.end()) { device->device.destroyBuffer(buf->buffer); throw e; } - } else { - // Out of Host/Device memory, clean up buffer - device->device.destroyBuffer(buf->buffer); - throw e; } } + + if (!buf->device_memory) { + device->device.destroyBuffer(buf->buffer); + throw vk::OutOfDeviceMemoryError("No suitable memory type found"); + } + buf->ptr = nullptr; if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { @@ -1819,7 +2009,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { try { - return ggml_vk_create_buffer(device, size, req_flags, fallback_flags); + return ggml_vk_create_buffer(device, size, {req_flags, fallback_flags}); } catch (const vk::SystemError& e) { std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl; std::cerr << "ggml_vulkan: " << e.what() << std::endl; @@ -1831,15 +2021,29 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { vk_buffer buf; try { if (device->prefer_host_memory) { - buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, + vk::MemoryPropertyFlagBits::eDeviceLocal}); } else if (device->uma) { // Fall back to host memory type - buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); } else if (device->disable_host_visible_vidmem) { - buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eDeviceLocal); + if (device->allow_sysmem_fallback) { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); + } else { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + } } else { // use rebar if available, otherwise fallback to device only visible memory - buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); + if (device->allow_sysmem_fallback) { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, + vk::MemoryPropertyFlagBits::eDeviceLocal, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); + } else { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, + vk::MemoryPropertyFlagBits::eDeviceLocal}); + } } } catch (const vk::SystemError& e) { std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl; @@ -1868,14 +2072,18 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { return { buf, 0, VK_WHOLE_SIZE }; } -static void ggml_vk_sync_buffers(vk_context& ctx) { +static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) { VK_LOG_DEBUG("ggml_vk_sync_buffers()"); - const bool transfer_queue = ctx->p->q->transfer_only; + const bool transfer_queue = subctx->p->q->transfer_only; - ctx->s->buffer.pipelineBarrier( - ctx->p->q->stage_flags, - ctx->p->q->stage_flags, + if (ctx) { + ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; + } + + subctx->s->buffer.pipelineBarrier( + subctx->p->q->stage_flags, + subctx->p->q->stage_flags, {}, { { { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }, @@ -1902,47 +2110,12 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events ); } -enum FaCodePath { - FA_SCALAR, - FA_COOPMAT1, - FA_COOPMAT2, -}; - -static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) { - if (hsk != 192 && hsk != 576 && hsk != hsv) { - return FA_HEAD_SIZE_UNSUPPORTED; - } - switch (hsk) { - case 64: return FA_HEAD_SIZE_64; - case 80: return FA_HEAD_SIZE_80; - case 96: return FA_HEAD_SIZE_96; - case 112: return FA_HEAD_SIZE_112; - case 128: return FA_HEAD_SIZE_128; - case 192: - if (hsv == 192) { - return FA_HEAD_SIZE_192; - } else if (hsv == 128) { - return FA_HEAD_SIZE_192_128; - } else { - return FA_HEAD_SIZE_UNSUPPORTED; - } - case 256: return FA_HEAD_SIZE_256; - case 576: - if (hsv == 512) { - return FA_HEAD_SIZE_576_512; - } else { - return FA_HEAD_SIZE_UNSUPPORTED; - } - default: return FA_HEAD_SIZE_UNSUPPORTED; - } -} - // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { - if (hsv >= 512) { + if (hsv >= 192) { return 2; } else { return 8; @@ -1972,7 +2145,13 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if (small_rows) { return {scalar_flash_attention_num_small_rows, 64}; } else { - return {get_fa_scalar_num_large_rows(hsv), 32}; + if ((hsv | hsk) & 8) { + // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter + // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. + return {get_fa_scalar_num_large_rows(hsv), 64}; + } else { + return {get_fa_scalar_num_large_rows(hsv), 32}; + } } } @@ -1990,8 +2169,8 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 } // small cols to reduce register count - if (ggml_is_quantized(type) || hsk >= 256) { - if (hsk >= 512) { + if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) { + if (hsk >= 512 || hsv >= 512) { return {32, 32}; } else { return {64, 32}; @@ -2000,6 +2179,10 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 return {64, 64}; } +static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) { + return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1]; +} + static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { uint32_t lut_size = 0; @@ -2038,10 +2221,11 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec const uint32_t warps = warptile[0] / warptile[10]; const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; - const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0; + const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0; const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; + const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0; - const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size; + const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " @@ -2125,8 +2309,17 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); + const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; + const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u); + const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u); + const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u); + + const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) || + (device->subgroup_size_control && device->subgroup_max_size >= 16); + // mulmat std::vector l_warptile, m_warptile, s_warptile, + l_warptile_id, m_warptile_id, s_warptile_id, l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int, l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, @@ -2163,9 +2356,9 @@ static void ggml_vk_load_shaders(vk_device& device) { s_mmq_wg_denoms_k = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul_id - l_warptile_mmqid = { 256, 128, 128, 16, 0 }; - m_warptile_mmqid = { 256, 128, 64, 16, 0 }; - s_warptile_mmqid = { 256, 128, 64, 16, 0 }; + l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size }; + m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; + s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; l_mmqid_wg_denoms = { 128, 128, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 }; s_mmqid_wg_denoms = { 128, 64, 1 }; @@ -2197,9 +2390,18 @@ static void ggml_vk_load_shaders(vk_device& device) { m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 }; + m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 }; + s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 }; + + l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 }; + m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 }; + s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 }; + // chip specific tuning if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; } l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; @@ -2225,14 +2427,14 @@ static void ggml_vk_load_shaders(vk_device& device) { } // Disable mul_mat_id if not enough shared memory is available - if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) { + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) { device->mul_mat_id_s[i] = false; device->mul_mat_id_m[i] = false; device->mul_mat_id_l[i] = false; - } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) { + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) { device->mul_mat_id_m[i] = false; device->mul_mat_id_l[i] = false; - } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) { + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) { device->mul_mat_id_l[i] = false; } } @@ -2255,7 +2457,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } std::vector> compiles; - auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, + auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { @@ -2265,11 +2467,14 @@ static void ggml_vk_load_shaders(vk_device& device) { if (!pipeline) { pipeline = std::make_shared(); + } + if (!pipeline->initialized) { pipeline->name = name; pipeline->parameter_count = parameter_count; pipeline->push_constant_size = push_constant_size; pipeline->wg_denoms = wg_denoms; pipeline->align = align; + pipeline->initialized = true; } if (!pipeline->needed || pipeline->compiled) { @@ -2289,6 +2494,14 @@ static void ggml_vk_load_shaders(vk_device& device) { parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; + auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint, + uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, + uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { + return ggml_vk_create_pipeline(device, pipeline, name.c_str(), spv_size, spv_data, entrypoint, + parameter_count, push_constant_size, wg_denoms, specialization_constants, + align, disable_robustness, require_full_subgroups, required_subgroup_size); + }; + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1}; }; @@ -2315,26 +2528,30 @@ static void ggml_vk_load_shaders(vk_device& device) { return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split}; }; -#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512) + for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ + uint32_t HSK = fa.first.HSK; \ + uint32_t HSV = fa.first.HSV; \ + bool small_rows = fa.first.small_rows; \ + FaCodePath path = fa.first.path; \ + bool aligned = fa.first.aligned; \ + bool f32acc = fa.first.f32acc; \ + if (path == FAPATH) { \ + if (aligned) { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + } \ + } else { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + } \ + } \ + } \ + } CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) @@ -2357,7 +2574,6 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) } #endif -#undef CREATE_FA2 #undef CREATE_FA #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) @@ -2404,32 +2620,34 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + GGML_ASSERT(device->subgroup_ballot); + + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) if (device->coopmat_bf16_support) { - CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) } #endif - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) #undef CREATE_MM #undef CREATE_MM2 } else @@ -2516,55 +2734,56 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + GGML_ASSERT(device->subgroup_ballot); + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) if (device->coopmat_bf16_support) { - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); } #endif - CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); #undef CREATE_MM2 #undef CREATE_MM } else #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->fp16) { // Create 6 variants, {s,m,l}x{unaligned,aligned} -#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## _l[TYPE]) { \ @@ -2581,38 +2800,38 @@ static void ggml_vk_load_shaders(vk_device& device) { } \ // Create 2 variants, {f16,f32} accumulator -#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - - CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - - CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); +#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -2624,51 +2843,77 @@ static void ggml_vk_load_shaders(vk_device& device) { } #endif - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + } else { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); + + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + } #undef CREATE_MM2 #undef CREATE_MMQ #undef CREATE_MM } else { // Create 6 variants, {s,m,l}x{unaligned,aligned} -#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## _l[TYPE]) \ @@ -2678,34 +2923,34 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - - CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -2717,33 +2962,59 @@ static void ggml_vk_load_shaders(vk_device& device) { } #endif - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - - CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + } else { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + } } // reusing CREATE_MM from the fp32 path if ((device->coopmat2 || device->coopmat_support) @@ -2760,8 +3031,8 @@ static void ggml_vk_load_shaders(vk_device& device) { m_wg_denoms = { 64, 64, 1 }; s_wg_denoms = { 32, 32, 1 }; - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); } #undef CREATE_MM @@ -2779,60 +3050,89 @@ static void ggml_vk_load_shaders(vk_device& device) { rm_stdq = 2; uint32_t rm_iq = 2 * rm_kq; + const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN; + // Ensure a subgroup size >= 16 is available + const bool use_subgroups16 = use_subgroups && subgroup_min_size_16; + + const uint32_t subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16) ? 16 : device->subgroup_size; + const uint32_t subgroup_size16 = std::max(subgroup_size, 16u); + + const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0; + const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0; + for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) { - uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_16 : (subgroup_size_16 * 4); - uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? device->subgroup_size : (device->subgroup_size * 4); + const uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4); + const uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size16 : (subgroup_size16 * 4); + + const shader_reduction_mode reduc = (use_subgroups && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP : + (use_subgroups && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID : + SHADER_REDUCTION_MODE_SHMEM; - const bool s = device->subgroup_add && device->architecture != vk_device_architecture::AMD_GCN; + const shader_reduction_mode reduc16 = (use_subgroups16 && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP : + (use_subgroups16 && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID : + SHADER_REDUCTION_MODE_SHMEM; for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_f32_f32_f32_len[s], arr_dmmv_f32_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_f16_f32_f32_len[s], arr_dmmv_f16_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_bf16_f32_f32_len[s], arr_dmmv_bf16_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q4_0_f32_f32_len[s], arr_dmmv_q4_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q4_1_f32_f32_len[s], arr_dmmv_q4_1_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q5_0_f32_f32_len[s], arr_dmmv_q5_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q5_1_f32_f32_len[s], arr_dmmv_q5_1_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q8_0_f32_f32_len[s], arr_dmmv_q8_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q2_k_f32_f32_len[s], arr_dmmv_q2_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q3_k_f32_f32_len[s], arr_dmmv_q3_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q4_k_f32_f32_len[s], arr_dmmv_q4_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q5_k_f32_f32_len[s], arr_dmmv_q5_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q6_k_f32_f32_len[s], arr_dmmv_q6_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq1_s_f32_f32_len[s], arr_dmmv_iq1_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq1_m_f32_f32_len[s], arr_dmmv_iq1_m_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq2_xxs_f32_f32_len[s], arr_dmmv_iq2_xxs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq2_xs_f32_f32_len[s], arr_dmmv_iq2_xs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq2_s_f32_f32_len[s], arr_dmmv_iq2_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq3_xxs_f32_f32_len[s], arr_dmmv_iq3_xxs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq3_s_f32_f32_len[s], arr_dmmv_iq3_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq4_xs_f32_f32_len[s], arr_dmmv_iq4_xs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq4_nl_f32_f32_len[s], arr_dmmv_iq4_nl_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_mxfp4_f32_f32_len[s], arr_dmmv_mxfp4_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_f32_f16_f32_len[s], arr_dmmv_f32_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_f16_f16_f32_len[s], arr_dmmv_f16_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_bf16_f16_f32_len[s], arr_dmmv_bf16_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q4_0_f16_f32_len[s], arr_dmmv_q4_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q4_1_f16_f32_len[s], arr_dmmv_q4_1_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q5_0_f16_f32_len[s], arr_dmmv_q5_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q5_1_f16_f32_len[s], arr_dmmv_q5_1_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q8_0_f16_f32_len[s], arr_dmmv_q8_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q2_k_f16_f32_len[s], arr_dmmv_q2_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q3_k_f16_f32_len[s], arr_dmmv_q3_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q4_k_f16_f32_len[s], arr_dmmv_q4_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q5_k_f16_f32_len[s], arr_dmmv_q5_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q6_k_f16_f32_len[s], arr_dmmv_q6_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq1_s_f16_f32_len[s], arr_dmmv_iq1_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq1_m_f16_f32_len[s], arr_dmmv_iq1_m_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq2_xxs_f16_f32_len[s], arr_dmmv_iq2_xxs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq2_xs_f16_f32_len[s], arr_dmmv_iq2_xs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq2_s_f16_f32_len[s], arr_dmmv_iq2_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq3_xxs_f16_f32_len[s], arr_dmmv_iq3_xxs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq3_s_f16_f32_len[s], arr_dmmv_iq3_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq4_xs_f16_f32_len[s], arr_dmmv_iq4_xs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq4_nl_f16_f32_len[s], arr_dmmv_iq4_nl_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_mxfp4_f16_f32_len[s], arr_dmmv_mxfp4_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32", arr_dmmv_q5_k_f32_f32_len[reduc16], arr_dmmv_q5_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32", arr_dmmv_q6_k_f32_f32_len[reduc16], arr_dmmv_q6_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32", arr_dmmv_iq1_s_f32_f32_len[reduc16], arr_dmmv_iq1_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32", arr_dmmv_iq1_m_f32_f32_len[reduc16], arr_dmmv_iq1_m_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32", arr_dmmv_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_iq2_xxs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32", arr_dmmv_iq2_xs_f32_f32_len[reduc16], arr_dmmv_iq2_xs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32", arr_dmmv_iq2_s_f32_f32_len[reduc16], arr_dmmv_iq2_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32", arr_dmmv_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_iq3_xxs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32", arr_dmmv_iq3_s_f32_f32_len[reduc16], arr_dmmv_iq3_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32", arr_dmmv_q5_k_f16_f32_len[reduc16], arr_dmmv_q5_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32", arr_dmmv_q6_k_f16_f32_len[reduc16], arr_dmmv_q6_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32", arr_dmmv_iq1_s_f16_f32_len[reduc16], arr_dmmv_iq1_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32", arr_dmmv_iq1_m_f16_f32_len[reduc16], arr_dmmv_iq1_m_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32", arr_dmmv_iq2_xxs_f16_f32_len[reduc16], arr_dmmv_iq2_xxs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32", arr_dmmv_iq2_xs_f16_f32_len[reduc16], arr_dmmv_iq2_xs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32", arr_dmmv_iq2_s_f16_f32_len[reduc16], arr_dmmv_iq2_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32", arr_dmmv_iq3_xxs_f16_f32_len[reduc16], arr_dmmv_iq3_xxs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32", arr_dmmv_iq3_s_f16_f32_len[reduc16], arr_dmmv_iq3_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; + const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + } +#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } } @@ -2924,21 +3224,32 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); + + if (device->subgroup_clustered && device->subgroup_require_full_support) { + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_subgroup_len, quantize_q8_1_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); + } for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { - if (device->subgroup_add && device->subgroup_require_full_support) { - ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true); + if (device->subgroup_arithmetic && device->subgroup_require_full_support) { + ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true); } else { - ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); + ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); } } ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 12 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -2947,12 +3258,16 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); @@ -3008,25 +3323,28 @@ static void ggml_vk_load_shaders(vk_device& device) { }; bool rte = device->float_controls_rte_fp16; -#define CREATE_BINARY(name, namemod, spec) \ +#define CREATE_BINARY(name, namemod, spec, bindings) \ for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ + ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \ - "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); - - CREATE_BINARY(add, , {0}) - CREATE_BINARY(add, _norepeat, {1}) - CREATE_BINARY(sub, , {0}) - CREATE_BINARY(sub, _norepeat, {1}) - CREATE_BINARY(mul, , {0}) - CREATE_BINARY(mul, _norepeat, {1}) - CREATE_BINARY(div, , {0}) - CREATE_BINARY(div, _norepeat, {1}) + "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); + + CREATE_BINARY(add, , {0}, 4) + CREATE_BINARY(add, _norepeat, {1}, 4) + CREATE_BINARY(sub, , {0}, 3) + CREATE_BINARY(sub, _norepeat, {1}, 3) + CREATE_BINARY(mul, , {0}, 3) + CREATE_BINARY(mul, _norepeat, {1}, 3) + CREATE_BINARY(div, , {0}, 3) + CREATE_BINARY(div, _norepeat, {1}, 3) + CREATE_BINARY(add_rms, , {0}, 4) + CREATE_BINARY(add_rms, _norepeat, {1}, 4) #undef CREATE_BINARY if (device->multi_add) { for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) { - ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1); + ggml_vk_create_pipeline2(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1); + ggml_vk_create_pipeline2(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1); } } @@ -3051,7 +3369,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -3062,6 +3380,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + CREATE_UNARY(exp) CREATE_UNARY(gelu) CREATE_UNARY(gelu_erf) CREATE_UNARY(gelu_quick) @@ -3069,6 +3388,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(relu) CREATE_UNARY(tanh) CREATE_UNARY(sigmoid) + CREATE_UNARY(hardsigmoid) + CREATE_UNARY(hardswish) #undef CREATE_UNARY #define CREATE_GLU(name) \ @@ -3097,7 +3418,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true); ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -3117,20 +3438,23 @@ static void ggml_vk_load_shaders(vk_device& device) { } for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { - ggml_vk_create_pipeline(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); } else { ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); @@ -3251,6 +3575,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); for (auto &c : compiles) { c.wait(); @@ -3295,6 +3621,12 @@ static vk_device ggml_vk_get_device(size_t idx) { const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM"); device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr; + const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK"); + device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr; + + const char* GGML_VK_DISABLE_GRAPH_OPTIMIZE = getenv("GGML_VK_DISABLE_GRAPH_OPTIMIZE"); + device->disable_graph_optimize = GGML_VK_DISABLE_GRAPH_OPTIMIZE != nullptr; + bool fp16_storage = false; bool fp16_compute = false; bool maintenance4_support = false; @@ -3302,6 +3634,7 @@ static vk_device ggml_vk_get_device(size_t idx) { bool amd_shader_core_properties2 = false; bool pipeline_robustness = false; bool coopmat2_support = false; + bool pipeline_executable_properties_support = false; device->coopmat_support = false; device->integer_dot_product = false; bool bfloat16_support = false; @@ -3344,6 +3677,8 @@ static vk_device ggml_vk_get_device(size_t idx) { !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; #endif + } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) { + pipeline_executable_properties_support = true; } } @@ -3433,11 +3768,21 @@ static vk_device ggml_vk_get_device(size_t idx) { } device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; - device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && - (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); - + device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); +#ifdef __APPLE__ + // Workaround for subgroup arithmetic failing on MoltenVK with AMD GPUs (issue 15846) + if (device->vendor_id == VK_VENDOR_ID_AMD) { + device->subgroup_arithmetic = false; + } +#endif device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); + device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered); + + device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot); const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; @@ -3560,8 +3905,18 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_shader_integer_dot_product"); } + VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {}; + pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR; + if (pipeline_executable_properties_support) { + last_struct->pNext = (VkBaseOutStructure *)&pep_features; + last_struct = (VkBaseOutStructure *)&pep_features; + device_extensions.push_back("VK_KHR_pipeline_executable_properties"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); + device->pipeline_executable_properties_support = pipeline_executable_properties_support; + device->fp16 = device->fp16 && vk12_features.shaderFloat16; #if defined(VK_KHR_shader_bfloat16) @@ -3588,9 +3943,7 @@ static vk_device ggml_vk_get_device(size_t idx) { (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) && subgroup_size_control_features.subgroupSizeControl; - if (device->subgroup_size_control) { - device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; - } + device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; #if defined(VK_KHR_cooperative_matrix) device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; @@ -3891,6 +4244,19 @@ static vk_device ggml_vk_get_device(size_t idx) { device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr; + device->add_rms_fusion = !device->disable_fusion && + device->subgroup_arithmetic && + device->vendor_id != VK_VENDOR_ID_INTEL; + device->partials_binding_alignment = + std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment); + + device->mmvq_mode = 0; + if (getenv("GGML_VK_DISABLE_MMVQ")) { + device->mmvq_mode = -1; + } else if (getenv("GGML_VK_FORCE_MMVQ")) { + device->mmvq_mode = 1; + } + return device; } @@ -4055,10 +4421,10 @@ static void ggml_vk_print_gpu_info(size_t idx) { } } -static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions); +static bool ggml_vk_instance_validation_ext_available(); static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); - static bool ggml_vk_instance_debug_utils_ext_available(const std::vector & instance_extensions); +static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev); static void ggml_vk_instance_init() { if (vk_instance_initialized) { @@ -4066,6 +4432,9 @@ static void ggml_vk_instance_init() { } VK_LOG_DEBUG("ggml_vk_instance_init()"); + // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- + VULKAN_HPP_DEFAULT_DISPATCHER.init(vkGetInstanceProcAddr); + uint32_t api_version = vk::enumerateInstanceVersion(); if (api_version < VK_API_VERSION_1_2) { @@ -4076,7 +4445,7 @@ static void ggml_vk_instance_init() { vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version }; const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); - const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions); + const bool validation_ext = ggml_vk_instance_validation_ext_available(); #ifdef __APPLE__ const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); #endif @@ -4129,15 +4498,19 @@ static void ggml_vk_instance_init() { vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT"); vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT"); vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT"); - } vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; + // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- + VULKAN_HPP_DEFAULT_DISPATCHER.init(vk_instance.instance); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES"); if (devices_env != nullptr) { - size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); + size_t num_available_devices = devices.size(); std::string devices(devices_env); std::replace(devices.begin(), devices.end(), ',', ' '); @@ -4152,8 +4525,6 @@ static void ggml_vk_instance_init() { vk_instance.device_indices.push_back(tmp); } } else { - std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); - // If no vulkan devices are found, return early if (devices.empty()) { GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); @@ -4169,7 +4540,7 @@ static void ggml_vk_instance_init() { new_driver.pNext = &new_id; devices[i].getProperties2(&new_props); - if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) { + if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) { // Check if there are two physical devices corresponding to the same GPU auto old_device = std::find_if( vk_instance.device_indices.begin(), @@ -4239,7 +4610,7 @@ static void ggml_vk_instance_init() { } } - // If no dedicated GPUs found, fall back to the first non-CPU device. + // If no GPUs found, fall back to the first non-CPU device. // If only CPU devices are available, return without devices. if (vk_instance.device_indices.empty()) { for (size_t i = 0; i < devices.size(); i++) { @@ -4258,6 +4629,19 @@ static void ggml_vk_instance_init() { GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size()); for (size_t i = 0; i < vk_instance.device_indices.size(); i++) { + vk::PhysicalDevice vkdev = devices[vk_instance.device_indices[i]]; + std::vector extensionprops = vkdev.enumerateDeviceExtensionProperties(); + + bool membudget_supported = false; + for (const auto & ext : extensionprops) { + if (strcmp(VK_EXT_MEMORY_BUDGET_EXTENSION_NAME, ext.extensionName) == 0) { + membudget_supported = true; + break; + } + } + + vk_instance.device_supports_membudget.push_back(membudget_supported); + ggml_vk_print_gpu_info(i); } } @@ -4404,9 +4788,22 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols, uint32_t m, uint32_t k) { VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); - GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16); + GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16 || b_type == GGML_TYPE_Q8_1); GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols); + if (b_type == GGML_TYPE_Q8_1) { + switch (a_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + break; + default: + return nullptr; + } + } + switch (a_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: @@ -4438,7 +4835,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * // heuristic to choose workgroup size uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP; - if (ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { // Prefer larger workgroups when M is small, to spread the work out more // and keep more SMs busy. // q6_k seems to prefer small workgroup size even for "medium" values of M. @@ -4453,6 +4850,13 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * } } + if (b_type == GGML_TYPE_Q8_1) { + if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + } + return ctx->device->pipeline_dequant_mul_mat_vec_q8_1_f32[dmmv_wg][a_type][num_cols-1]; + } + return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[dmmv_wg][a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[dmmv_wg][a_type][num_cols-1]; } @@ -4522,7 +4926,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co } static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { - VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); + VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()"); GGML_ASSERT(b_type == GGML_TYPE_F32); switch (a_type) { @@ -4625,8 +5029,8 @@ static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_ static void * ggml_vk_host_malloc(vk_device& device, size_t size) { VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")"); vk_buffer buf = ggml_vk_create_buffer(device, size, - vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, - vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) { fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n", @@ -4857,7 +5261,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont } } - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); return; } @@ -4872,7 +5276,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size); VkBufferCopy buf_copy{ 0, offset, copy_size }; - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); for (uint64_t i3 = 0; i3 < ne3; i3++) { @@ -4926,7 +5330,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz } } - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(nullptr, subctx); subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); return; } @@ -4947,7 +5351,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz offset, copy_size}; - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(nullptr, subctx); vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); if (width == spitch) { @@ -5027,7 +5431,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size if (buf != nullptr) { // Memory is pinned, use as staging buffer - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(nullptr, subctx); subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices); return; @@ -5044,7 +5448,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size vk_buffer& staging_buffer = src->device->sync_staging; - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(nullptr, subctx); subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices); deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); @@ -5234,13 +5638,16 @@ static void ggml_vk_matmul( uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, uint32_t padded_n) { VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); - ggml_vk_sync_buffers(subctx); if (split_k == 1) { const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch }); return; } + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + GGML_ASSERT(batch_stride_d == m * n); // Round the split size up to a multiple of 256 (k-quant alignment) @@ -5250,9 +5657,10 @@ static void ggml_vk_matmul( const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; // Make sure enough workgroups get assigned for split k to work ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 }); + ctx->prealloc_split_k_need_sync = true; } static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { @@ -5297,7 +5705,6 @@ static void ggml_vk_matmul_id( "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")"); - ggml_vk_sync_buffers(subctx); const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, nei0, nei1, nbi1, ne11, padded_n }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as }); @@ -5350,6 +5757,20 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f32_bf16; } } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_i32; + } else { + return ctx->device->pipeline_cpy_f32_i32; + } + } + if (src->type == GGML_TYPE_I32 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_i32_f32; + } else { + return ctx->device->pipeline_cpy_i32_f32; + } + } if (src->type == GGML_TYPE_F32) { switch (to) { case GGML_TYPE_Q4_0: @@ -5428,27 +5849,27 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }; init_pushconst_fastdiv(pc); - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); } -static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) { +static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type, bool use_x4_blocks) { switch(type) { case GGML_TYPE_Q8_1: - return ctx->device->pipeline_quantize_q8_1; + return use_x4_blocks ? ctx->device->pipeline_quantize_q8_1_x4 : ctx->device->pipeline_quantize_q8_1; default: std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl; GGML_ABORT("fatal error"); } } -static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) { +static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne, bool use_x4_blocks = false) { VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")"); - vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + vk_pipeline pipeline = use_x4_blocks ? ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true) : ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false); - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 }); + ggml_vk_sync_buffers(ctx, subctx); } static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -5565,12 +5986,15 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT if (quantize_y) { - to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true); } if (dryrun) { const uint64_t x_sz_upd = x_sz * ne02 * ne03; - const uint64_t y_sz_upd = y_sz * ne12 * ne13; + uint64_t y_sz_upd = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; + } const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0; if ( (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || @@ -5636,25 +6060,47 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); } else if (quantize_y) { d_Y = ctx->prealloc_y; - GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)); + GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); } else { d_Y = d_Qy; y_buf_offset = qy_buf_offset; GGML_ASSERT(qy_sz == y_sz); } + if (x_non_contig || qx_needs_dequant) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + if (x_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); } else if (qx_needs_dequant) { const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + ggml_vk_sync_buffers(ctx, subctx); } if (y_non_contig) { - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } } if (quantize_y) { - ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13); + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } } uint32_t stride_batch_x = ne00*ne01; @@ -5668,15 +6114,72 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); } + uint32_t y_sz_total = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; + } + // compute ggml_vk_matmul( ctx, subctx, pipeline, - { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total }, { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n ); // NOLINT + + if (x_non_contig || qx_needs_dequant) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig || quantize_y) { + ctx->prealloc_y_need_sync = true; + } +} + +// Device tuning +static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_t n, uint32_t k, ggml_type src0_type) { + if (device->mmvq_mode == 1) { + return true; + } else if (device->mmvq_mode == -1) { + return false; + } + + // MMVQ is generally good for batches + if (n > 1) { + return true; + } + + switch (device->vendor_id) { + case VK_VENDOR_ID_NVIDIA: + switch (src0_type) { + case GGML_TYPE_Q8_0: + return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; + default: + return true; + } + case VK_VENDOR_ID_AMD: + switch (src0_type) { + case GGML_TYPE_Q8_0: + return device->architecture == vk_device_architecture::AMD_GCN; + default: + return true; + } + case VK_VENDOR_ID_INTEL: + switch (src0_type) { + // From tests on A770 Linux, may need more tuning + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q5_1: + return false; + default: + return true; + } + default: + return true; + } + + GGML_UNUSED(m); + GGML_UNUSED(k); } static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -5733,22 +6236,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; - - const bool qx_needs_dequant = x_non_contig; - const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; - - // Not implemented - GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - - const uint64_t x_ne = ne01 * ne00; - const uint64_t y_ne = ne11 * ne10; - const uint64_t d_ne = ne11 * ne01; - - const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); - const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); - const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; - const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; - const uint64_t d_sz = sizeof(float) * d_ne; + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type); vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; @@ -5760,14 +6248,47 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } - vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00); + + // Check for mmq first + vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11, ne20, ne00) : nullptr; + vk_pipeline to_q8_1 = nullptr; + + if (dmmv == nullptr) { + // Fall back to f16 dequant mul mat + dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00); + quantize_y = false; + } + + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true); + } + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig); + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT GGML_ASSERT(dmmv != nullptr); + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne11 * ne01; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); + const uint64_t d_sz = sizeof(float) * d_ne; + if (dryrun) { const uint64_t x_sz_upd = x_sz * ne02 * ne03; - const uint64_t y_sz_upd = y_sz * ne12 * ne13; + uint64_t y_sz_upd = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; + } if ( (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { @@ -5776,7 +6297,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { ctx->prealloc_size_x = x_sz_upd; } - if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { ctx->prealloc_size_y = y_sz_upd; } @@ -5787,6 +6308,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (qy_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); + } ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); return; } @@ -5817,6 +6341,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& } if (qy_needs_dequant) { d_Y = ctx->prealloc_y; + } else if (quantize_y) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); } else { d_Y = d_Qy; y_buf_offset = qy_buf_offset; @@ -5824,12 +6351,35 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& } if (x_non_contig) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); } if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } + if (quantize_y) { + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } } // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride @@ -5855,16 +6405,28 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& groups_x = CEIL_DIV(groups_x, groups_z); } + // TODO: Clean up this whole sz * ne_2 * ne_3 thing, it hasn't been necessary for a long time + uint32_t y_sz_total = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; + } + // compute const vk_mat_vec_push_constants pc = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, stride_batch_x, stride_batch_y, stride_batch_d, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, - { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} }, + { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz_total }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} }, pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); + + if (x_non_contig) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig || quantize_y) { + ctx->prealloc_y_need_sync = true; + } } static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -5951,7 +6513,6 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c workgroups_z /= gqa_ratio; } - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z }); } @@ -6038,7 +6599,6 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con // compute const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23 }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 }); } @@ -6089,7 +6649,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t nei0 = ids->ne[0]; const uint64_t nei1 = ids->ne[1]; - GGML_ASSERT(nei0 * nei1 <= 4096); const uint32_t nbi1 = ids->nb[1]; const uint32_t nbi2 = ids->nb[2]; @@ -6250,16 +6809,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& GGML_ASSERT(qy_sz == y_sz); } + if (x_non_contig || qx_needs_dequant) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + if (x_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); } else if (qx_needs_dequant) { const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + ggml_vk_sync_buffers(ctx, subctx); } if (y_non_contig) { - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } } uint32_t stride_batch_x = ne00*ne01; @@ -6282,6 +6855,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& stride_batch_x, stride_batch_y, ne20*ne21, n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n ); // NOLINT + + if (x_non_contig || qx_needs_dequant) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig) { + ctx->prealloc_y_need_sync = true; + } } static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { @@ -6441,13 +7021,27 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte GGML_ASSERT(qy_sz == y_sz); } + if (x_non_contig) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + if (x_non_contig) { GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); } if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } } uint32_t stride_batch_y = ne10*ne11; @@ -6472,11 +7066,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21), (uint32_t)nei0, (uint32_t)ne11, }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } }, pc, { groups_x, (uint32_t)nei0, groups_z }); + + if (x_non_contig) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig) { + ctx->prealloc_y_need_sync = true; + } } static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { @@ -6484,30 +7084,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); } else { - // Split based on number of ids, to fit in shared memory - const uint32_t nei0 = (uint32_t)src2->ne[0]; - const uint32_t nei1 = (uint32_t)src2->ne[1]; - - GGML_ASSERT(nei0 <= 4096); - const uint32_t split_size = std::min(nei1, 4096u / nei0); - - ggml_tensor src1_copy = *src1; - ggml_tensor src2_copy = *src2; - ggml_tensor dst_copy = *dst; - - for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) { - const uint32_t n_tokens = std::min(split_size, nei1 - token_start); - - src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2]; - src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1]; - dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2]; - - src1_copy.ne[2] = n_tokens; - src2_copy.ne[1] = n_tokens; - dst_copy.ne[2] = n_tokens; - - ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun); - } + ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); } } @@ -6540,18 +7117,21 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t Br = coopmat1_flash_attention_num_large_rows; const uint32_t Bc = scalar_flash_attention_Bc; + const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16); + const uint32_t acctype = f32acc ? 4 : 2; const uint32_t f16vec4 = 8; const uint32_t tmpsh = wg_size * sizeof(float); const uint32_t tmpshv4 = wg_size * 4 * acctype; - const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4; + const uint32_t qstride = hsk_pad / 4 + 2; + const uint32_t Qf = Br * qstride * f16vec4; const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br; const uint32_t sfsh = Bc * sfshstride * acctype; - const uint32_t kshstride = hsk / 4 + 2; + const uint32_t kshstride = hsk_pad / 4 + 2; const uint32_t ksh = Bc * kshstride * f16vec4; const uint32_t slope = Br * sizeof(float); @@ -6662,7 +7242,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx workgroups_y /= N; } - vk_pipeline *pipelines; bool small_rows = N <= get_fa_num_small_rows(path); // coopmat1 does not actually support "small rows" (it needs 16 rows). @@ -6676,43 +7255,42 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx path = FA_SCALAR; } - // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory - if (path == FA_SCALAR && - !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) { - small_rows = true; - } - - bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; - - FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]); - - switch (path) { - case FA_SCALAR: - pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0]; - break; - case FA_COOPMAT1: - pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0]; - break; - case FA_COOPMAT2: - pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0]; - break; - default: - GGML_ASSERT(0); + // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory + if (path == FA_SCALAR && + !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) { + small_rows = true; } - assert(pipelines); const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); - bool aligned = (KV % pipelines[1]->align) == 0 && + uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows); + bool aligned = (KV % alignment) == 0 && // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; + // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned. + if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) { + aligned = false; + } // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0); - vk_pipeline pipeline = pipelines[aligned]; + bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + + vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc); + + vk_pipeline pipeline = nullptr; + + auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type]; + auto it = pipelines.find(fa_pipeline_state); + if (it != pipelines.end()) { + pipeline = it->second; + } else { + pipelines[fa_pipeline_state] = pipeline = std::make_shared(); + } + assert(pipeline); uint32_t split_kv = KV; @@ -6728,7 +7306,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (split_k > 1) { // Try to evenly split KV into split_k chunks, but it needs to be a multiple // of "align", so recompute split_k based on that. - split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align); + split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); split_k = CEIL_DIV(KV, split_kv); workgroups_x = split_k; } @@ -6852,9 +7430,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx mask_n_head_log2, m0, m1, gqa_ratio, split_kv, split_k }; - ggml_vk_sync_buffers(subctx); - if (split_k > 1) { + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, @@ -6870,7 +7450,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // cancel out the divide by wg_denoms[0]. pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, { @@ -6879,6 +7459,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, }, pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); + ctx->prealloc_split_k_need_sync = true; } else { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { @@ -6921,7 +7502,7 @@ static std::array ggml_vk_get_conv_elements(const ggml_tensor *dst) return elements; } -static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { +static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) { switch (op) { case GGML_OP_GET_ROWS: GGML_ASSERT(src1->type == GGML_TYPE_I32); @@ -6950,10 +7531,19 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_ADD: { if (ctx->num_additional_fused_ops > 0) { - return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops]; + if (ctx->do_add_rms_partials) { + return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops]; + } else { + return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops]; + } + } + if (ctx->do_add_rms_partials) { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } else { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; } - auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add; - return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; } case GGML_OP_SUB: { @@ -7076,7 +7666,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_RMS_NORM: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32; + if (ctx->do_add_rms_partials) { + return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32; + } else { + return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32; + } } return nullptr; case GGML_OP_RMS_NORM_BACK: @@ -7097,6 +7691,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } switch (ggml_get_unary_op(dst)) { + case GGML_UNARY_OP_EXP: + return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_SILU: return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU: @@ -7111,6 +7707,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_SIGMOID: return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_HARDSIGMOID: + return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_HARDSWISH: + return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16]; default: break; } @@ -7207,6 +7807,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_sum_rows_f32; } @@ -7229,6 +7830,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_im2col_f32_f16; } return nullptr; + case GGML_OP_IM2COL_3D: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_im2col_3d_f32; + } + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_im2col_3d_f32_f16; + } + return nullptr; case GGML_OP_TIMESTEP_EMBEDDING: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_timestep_embedding_f32; @@ -7306,6 +7915,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } else if (ggml_is_contiguous_channels(src1)) { return ctx->device->pipeline_conv2d_dw_cwhn_f32; } + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + if (ggml_is_contiguous(src1)) { + return ctx->device->pipeline_conv2d_dw_whcn_f16_f32; + } else if (ggml_is_contiguous_channels(src1)) { + return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32; + } } return nullptr; default: @@ -7338,7 +7953,11 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_RMS_NORM: case GGML_OP_CONV_2D_DW: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_SET_ROWS: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: return true; default: return false; @@ -7373,6 +7992,36 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src2); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src0); + GGML_UNUSED(src2); +} + template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); @@ -7523,10 +8172,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); if (op_supports_incontiguous) { - x_sz = ggml_nbytes(src0); - y_sz = use_src1 ? ggml_nbytes(src1) : 0; - z_sz = use_src2 ? ggml_nbytes(src2) : 0; - d_sz = ggml_nbytes(dst); + x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0); + y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0; + z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0; + d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst); if (x_buf_offset + x_sz >= d_X->size) { x_sz = VK_WHOLE_SIZE; @@ -7554,6 +8203,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGMAX: { const uint32_t nr = ggml_nrows(src0); @@ -7566,7 +8216,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } break; case GGML_OP_RMS_NORM: - elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; + if (ctx->do_add_rms_partials) { + // Run one element per thread, 128 threads per workgroup + elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 }; + } else { + elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; + } break; case GGML_OP_SUM: @@ -7585,6 +8240,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co break; case GGML_OP_GET_ROWS: elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); break; case GGML_OP_ARGSORT: elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; @@ -7605,6 +8262,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { OW * KW * KH, OH, batch * IC }; } break; + case GGML_OP_IM2COL_3D: + { + const uint32_t IC = ((const uint32_t *)(dst->op_params))[9]; + + const uint32_t N = ne13 / IC; + + const uint32_t KD = ne02; + const uint32_t KH = ne01; + const uint32_t KW = ne00; + + const uint32_t OD = ned3 / N; + const uint32_t OH = ned2; + const uint32_t OW = ned1; + + const uint32_t IC_KD_KH_KW = IC*KD*KH*KW; + const uint32_t N_OD_OH = N*OD*OH; + + elements = { IC_KD_KH_KW, OW, N_OD_OH }; + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + } break; case GGML_OP_TIMESTEP_EMBEDDING: { const uint32_t dim = dst->op_params[0]; @@ -7715,7 +8392,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } - if (op == GGML_OP_GLU) { + if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) { + vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X; + size_t a_buf_offset = ctx->do_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { vk_subbuffer{ d_X, x_buf_offset, x_sz }, + vk_subbuffer{ d_Y, y_buf_offset, y_sz }, + vk_subbuffer{ d_D, d_buf_offset, d_sz }, + vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE }, + }, pc, elements); + } else if (op == GGML_OP_GLU) { // Empty src1 is possible in glu, but the shader needs a buffer vk_subbuffer subbuf_y; if (use_src1) { @@ -7724,7 +8410,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co subbuf_y = { d_X, 0, x_sz }; } - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_SOFT_MAX) { // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer @@ -7742,7 +8427,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co subbuf_z = { d_X, 0, x_sz }; } - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { // Empty src2 is possible in rope, but the shader needs a buffer @@ -7753,30 +8437,23 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co subbuf_z = { d_X, 0, x_sz }; } - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); - } else if (op == GGML_OP_IM2COL) { + } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) { // im2col uses only src1 and dst buffers - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_COUNT_EQUAL) { - ggml_vk_sync_buffers(subctx); // count_equal assumes that destination buffer is initialized with zeroes ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz); - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_OPT_STEP_SGD) { // OPT_STEP_SGD works on src0, it does not need dst - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements); } else if (use_src2) { - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (use_src1) { - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else { - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } } @@ -7825,7 +8502,7 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor *tensors[MAX_PARAMETER_COUNT]; uint32_t num_srcs = ctx->num_additional_fused_ops + 2; uint32_t num_tensors = num_srcs + 1; - GGML_ASSERT(num_tensors <= MAX_PARAMETER_COUNT); + GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT); tensors[0] = first_node->src[0]; tensors[1] = first_node->src[1]; @@ -7852,8 +8529,9 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float); pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float); } + pc.rms_partials = ctx->do_add_rms_partials; - vk_pipeline pipeline = ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops]; + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, tensors[0], tensors[1], nullptr, dst, dst->op); if (pipeline == nullptr) { std::cerr << "ggml_vulkan: Error: Missing multi_add"; @@ -7891,6 +8569,10 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, buf[i] = buf[0]; offset[i] = 0; } + if (ctx->do_add_rms_partials) { + buf[num_tensors] = ctx->prealloc_add_rms_partials; + offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset; + } std::array elements; @@ -7903,7 +8585,7 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, elements = { ne, 1, 1 }; } - ggml_vk_sync_buffers(subctx); + static_assert(MAX_PARAMETER_COUNT == 12); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE }, @@ -7914,6 +8596,10 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE }, vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE }, vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[8], offset[8], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[9], offset[9], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[10], offset[10], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[11], offset[11], VK_WHOLE_SIZE }, }, pc, elements); } @@ -7928,7 +8614,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - 0.0f, 0.0f, 0, + 0.0f, 0.0f, ctx->do_add_rms_partials, }, dryrun); } @@ -8016,8 +8702,6 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context; } - ggml_vk_sync_buffers(subctx); - vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 }; bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false }; @@ -8155,8 +8839,6 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context; ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context; - ggml_vk_sync_buffers(subctx); - vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr; size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0; bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false; @@ -8302,7 +8984,7 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con } static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); + vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun); } @@ -8390,19 +9072,39 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); } +static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) { + const uint32_t ne = (uint32_t)node->ne[0]; + const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0]; + const uint32_t num_partials = CEIL_DIV(ne, denom); + return num_partials; +} + +static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) { + const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node); + const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment); + return num_bytes; +} + static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); + uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - op_params[0], 0.0f, 0, + op_params[0], 0.0f, (int32_t)param3, }, dryrun); + + if (ctx->do_add_rms_partials) { + ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0); + ctx->do_add_rms_partials = false; + } } static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -8493,7 +9195,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun); } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) { @@ -8540,11 +9242,19 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c } static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0)); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun); } static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun); +} + +static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + p.weight = 1.0f / (float)src0->ne[0]; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun); } static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { @@ -8589,6 +9299,66 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co }, dryrun); } +static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + + vk_op_im2col_3d_push_constants pc {}; + + pc.nb10 = nb10 / ggml_type_size(src1->type); + pc.nb11 = nb11 / ggml_type_size(src1->type); + pc.nb12 = nb12 / ggml_type_size(src1->type); + pc.nb13 = nb13 / ggml_type_size(src1->type); + pc.s0 = s0; + pc.s1 = s1; + pc.s2 = s2; + pc.p0 = p0; + pc.p1 = p1; + pc.p2 = p2; + pc.d0 = d0; + pc.d1 = d1; + pc.d2 = d2; + pc.IW = IW; + pc.IH = IH; + pc.ID = ID; + pc.IC = IC; + pc.KW = KW; + pc.OH = OH; + pc.KD_KH_KW = KD*KH*KW; + pc.KH_KW = KH*KW; + pc.IC_KD_KH_KW = IC*KD*KH*KW; + pc.N_OD_OH = N*OD*OH; + pc.OD_OH = OD*OH; + pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; + pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun); +} + static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const uint32_t dim = dst->op_params[0]; const uint32_t max_period = dst->op_params[1]; @@ -8888,7 +9658,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t if (ctx->prealloc_split_k != nullptr) { ggml_vk_destroy_buffer(ctx->prealloc_split_k); } - ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); + ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal}); } } @@ -8898,9 +9668,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_pipeline_allocate_descriptor_sets(ctx); - vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); - vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); - vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne); Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne); @@ -9126,8 +9896,8 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); float * x = (float *) malloc(x_sz); void * qx = malloc(qx_sz); - vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); - vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, {vk::MemoryPropertyFlagBits::eDeviceLocal}); float * x_ref = (float *) malloc(x_sz); ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16); @@ -9232,8 +10002,8 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ // float * x = (float *) malloc(x_sz); // block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz); // block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz); -// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); -// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); +// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); +// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); // // for (size_t i = 0; i < ne; i++) { // x[i] = rand() / (float)RAND_MAX; @@ -9380,10 +10150,10 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, float * x = (float *) malloc(x_sz); float * y = (float *) malloc(y_sz); void * qx = malloc(qx_sz); - vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); - vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); - vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); - vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); float * d = (float *) malloc(d_sz); float * d_chk = (float *) malloc(d_sz); @@ -9410,7 +10180,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, if (ctx->prealloc_split_k != nullptr) { ggml_vk_destroy_buffer(ctx->prealloc_split_k); } - ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); + ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal}); } } if (mmq) { @@ -9672,6 +10442,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k); } + if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_add_rms_partials << ")"); + // Resize buffer + if (ctx->prealloc_add_rms_partials != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials); + } + ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials); + } } static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); @@ -9702,6 +10480,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr return false; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_ERF: @@ -9709,6 +10488,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: break; default: return false; @@ -9727,10 +10508,23 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr return false; } break; + case GGML_OP_ADD: + { + int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops; + if (next_node_idx < cgraph->n_nodes && + cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM && + cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] && + ggml_nrows(cgraph->nodes[next_node_idx]) == 1 && + ctx->device->add_rms_fusion) { + if (dryrun) { + ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]); + } + ctx->do_add_rms_partials = true; + } + } break; case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_GET_ROWS: - case GGML_OP_ADD: case GGML_OP_ADD_ID: case GGML_OP_ACC: case GGML_OP_SUB: @@ -9766,9 +10560,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: @@ -9835,9 +10631,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: @@ -9850,6 +10648,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr // do the only thing needed for the dryrun. vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + if (node->op == GGML_OP_RMS_NORM) { + ctx->do_add_rms_partials = false; + } return false; } default: @@ -9857,6 +10658,80 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } } + if (!dryrun) { + // This logic detects dependencies between modes in the graph and calls ggml_vk_sync_buffers + // to synchronize them. This handles most "normal" synchronization when computing the graph, and when + // there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers + // outside of this logic. When a node uses one of the prealloc buffers for something like + // dequantization or split_k, additional synchronization is needed between those passes. + bool need_sync = false; + + // Check whether "node" requires synchronization. The node requires synchronization if it + // overlaps in memory with another unsynchronized node and at least one of them is a write. + // Destination nodes are checked against both the written/read lists. Source nodes are only + // checked against the written list. Two nodes overlap in memory if they come from the same + // buffer and the tensor or view ranges overlap. + auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector &unsynced_nodes) -> bool { + if (unsynced_nodes.size() == 0) { + return false; + } + auto n_base = vk_tensor_offset(node) + node->view_offs; + auto n_size = ggml_nbytes(node); + ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context; + vk_buffer a_buf = a_buf_ctx->dev_buffer; + for (auto &other : unsynced_nodes) { + ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context; + vk_buffer o_buf = o_buf_ctx->dev_buffer; + if (a_buf == o_buf) { + auto o_base = vk_tensor_offset(other) + other->view_offs; + auto o_size = ggml_nbytes(other); + + if ((o_base <= n_base && n_base < o_base + o_size) || + (n_base <= o_base && o_base < n_base + n_size)) { + return true; + } + } + } + return false; + }; + + // For all fused ops, check if the destination node or any of the source + // nodes require synchronization. + for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) { + const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; + if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) { + need_sync = true; + break; + } + for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { + if (!cur_node->src[j]) { + continue; + } + if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) { + need_sync = true; + break; + } + } + } + if (need_sync) { + ctx->unsynced_nodes_written.clear(); + ctx->unsynced_nodes_read.clear(); + ggml_vk_sync_buffers(ctx, compute_ctx); + } + // Add the last fused node and all fused source nodes to the unsynchronized list. + const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; + ctx->unsynced_nodes_written.push_back(last_node); + for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { + const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; + for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { + if (!cur_node->src[j]) { + continue; + } + ctx->unsynced_nodes_read.push_back(cur_node->src[j]); + } + } + } + switch (node->op) { case GGML_OP_REPEAT: ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); @@ -9979,6 +10854,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_ERF: @@ -9986,6 +10862,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun); break; default: @@ -10037,6 +10915,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_SUM_ROWS: ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_MEAN: + ggml_vk_mean(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_ARGMAX: ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun); @@ -10049,6 +10931,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_IM2COL: ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_IM2COL_3D: + ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_TIMESTEP_EMBEDDING: ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); @@ -10196,9 +11082,11 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: @@ -10215,6 +11103,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * break; case GGML_OP_UNARY: switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_ERF: @@ -10222,6 +11111,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: buf = tensor->buffer; break; default: @@ -10311,6 +11202,11 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ggml_vk_pool_free(ctx, buffer); } ctx->gc.temp_buffers.clear(); + ctx->prealloc_y_last_pipeline_used = {}; + + ctx->unsynced_nodes_written.clear(); + ctx->unsynced_nodes_read.clear(); + ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); @@ -10346,6 +11242,7 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ggml_vk_destroy_buffer(ctx->prealloc_x); ggml_vk_destroy_buffer(ctx->prealloc_y); ggml_vk_destroy_buffer(ctx->prealloc_split_k); + ctx->prealloc_y_last_pipeline_used = nullptr; for (auto& buffer : ctx->buffer_pool) { ggml_vk_destroy_buffer(buffer); @@ -10829,6 +11726,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast(&dul)); } + ctx->prealloc_size_add_rms_partials = 0; + ctx->prealloc_size_add_rms_partials_offset = 0; + ctx->do_add_rms_partials = false; + uint64_t total_mat_mul_bytes = 0; for (int i = 0; i < cgraph->n_nodes; i++) { if (!ctx->device->disable_fusion) { @@ -10894,6 +11795,22 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0); } + ctx->prealloc_y_last_pipeline_used = nullptr; + ctx->prealloc_y_last_tensor_used = nullptr; + + if (ctx->prealloc_size_add_rms_partials) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + // initialize partial sums to zero. + ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials); + ggml_vk_sync_buffers(ctx, compute_ctx); + } + // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB // (and scaled down based on model size, so smaller models submit earlier). @@ -10996,6 +11913,131 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg UNUSED(backend); } +// Sort the graph for improved parallelism. +static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph) +{ + VK_LOG_DEBUG("ggml_vk_graph_optimize(" << graph->n_nodes << " nodes)"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + if (ctx->device->disable_graph_optimize) { + return; + } + + auto const &is_empty = [](ggml_tensor * node) -> bool { + return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; + }; + + auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + if (dst->src[s] == src) { + return true; + } + } + // implicit dependency if they view the same tensor + const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst; + const ggml_tensor *src2 = src->view_src ? src->view_src : src; + if (dst2 == src2) { + return true; + } + return false; + }; + + // This function tries to reorder the graph to allow nodes to run in parallel. + // This helps with small batches, but for large batches its a slowdown, probably + // due to cache contention. So only reorder if the majority of nodes have few rows. + int num_small_nodes = 0; + int num_counted_nodes = 0; + for (int i = 0; i < graph->n_nodes; ++i) { + if (!is_empty(graph->nodes[i]) && + graph->nodes[i]->op != GGML_OP_SET_ROWS) { + if (ggml_nrows(graph->nodes[i]) <= 8) { + num_small_nodes++; + } + num_counted_nodes++; + } + } + if (num_small_nodes < num_counted_nodes / 2) { + return; + } + + std::vector new_order; + std::vector used(graph->n_nodes, false); + int first_unused = 0; + while (first_unused < graph->n_nodes) { + std::vector current_set; + + // First, grab the next unused node. + current_set.push_back(first_unused); + + // Loop through the next N nodes. Grab any that don't depend on other nodes that + // haven't already been run. Nodes that have already been run have used[i] set + // to true. Allow nodes that depend on the previous node if it's a fusion pattern + // that we support (e.g. RMS_NORM + MUL). + // This first pass only grabs "real" (non-view nodes). Second pass grabs view nodes. + // The goal is to not interleave real and view nodes in a way that breaks fusion. + const int NUM_TO_CHECK = 20; + for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (is_empty(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + if (!used[c] && + is_src_of(graph->nodes[j], graph->nodes[c]) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL)) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + // Second pass grabs view nodes. + // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add). + if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) { + for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (!is_empty(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end(); + // skip views whose srcs haven't been processed. + if (!used[c] && + is_src_of(graph->nodes[j], graph->nodes[c]) && + !c_in_current_set) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + } + + // Push the current set into new_order + for (auto c : current_set) { + new_order.push_back(graph->nodes[c]); + used[c] = true; + } + while (first_unused < graph->n_nodes && used[first_unused]) { + first_unused++; + } + } + // Replace the graph with the new order. + for (int i = 0; i < graph->n_nodes; ++i) { + graph->nodes[i] = new_order[i]; + } +} + // TODO: enable async and synchronize static ggml_backend_i ggml_backend_vk_interface = { /* .get_name = */ ggml_backend_vk_name, @@ -11011,6 +12053,7 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .graph_compute = */ ggml_backend_vk_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_optimize = */ ggml_vk_graph_optimize, }; static ggml_guid_t ggml_backend_vk_guid() { @@ -11050,18 +12093,81 @@ void ggml_backend_vk_get_device_description(int device, char * description, size void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; + vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops; + vk::PhysicalDeviceMemoryProperties2 memprops = {}; + bool membudget_supported = vk_instance.device_supports_membudget[device]; + + if (membudget_supported) { + memprops.pNext = &budgetprops; + } + vkdev.getMemoryProperties2(&memprops); - vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); + for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) { + const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i]; - for (const vk::MemoryHeap& heap : memprops.memoryHeaps) { if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { *total = heap.size; - *free = heap.size; + + if (membudget_supported && i < budgetprops.heapUsage.size()) { + *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i]; + } else { + *free = heap.size; + } + break; + } + } +} + +static vk::PhysicalDeviceType ggml_backend_vk_get_device_type(int device_idx) { + GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size()); + + vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]]; + + vk::PhysicalDeviceProperties2 props = {}; + device.getProperties2(&props); + + return props.properties.deviceType; +} + +static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { + GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size()); + + vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]]; + + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool ext_support = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_EXT_pci_bus_info", properties.extensionName) == 0) { + ext_support = true; break; } } + + if (!ext_support) { + return ""; + } + + vk::PhysicalDeviceProperties2 props = {}; + vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_info = {}; + + props.pNext = &pci_bus_info; + + device.getProperties2(&props); + + const uint32_t pci_domain = pci_bus_info.pciDomain; + const uint32_t pci_bus = pci_bus_info.pciBus; + const uint32_t pci_device = pci_bus_info.pciDevice; + const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning + + char pci_bus_id[16] = {}; + snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function); + + return std::string(pci_bus_id); } ////////////////////////// @@ -11070,6 +12176,8 @@ struct ggml_backend_vk_device_context { size_t device; std::string name; std::string description; + bool is_integrated_gpu; + std::string pci_bus_id; }; static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { @@ -11098,14 +12206,18 @@ static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(gg } static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) { - UNUSED(dev); - return GGML_BACKEND_DEVICE_TYPE_GPU; + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + + return ctx->is_integrated_gpu ? GGML_BACKEND_DEVICE_TYPE_IGPU : GGML_BACKEND_DEVICE_TYPE_GPU; } static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); props->type = ggml_backend_vk_device_get_type(dev); + props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = { /* .async = */ false, @@ -11125,6 +12237,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: @@ -11132,6 +12245,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && @@ -11223,8 +12338,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; auto device = ggml_vk_get_device(ctx->device); bool coopmat2 = device->coopmat2; - FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]); - if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) { + uint32_t HSK = op->src[1]->ne[0]; + uint32_t HSV = op->src[2]->ne[0]; + if ((HSK % 8) != 0 || (HSV % 8) != 0) { return false; } if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) { @@ -11367,6 +12483,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; } + if ( + (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) || + (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) + ) { + return true; + } + // We can handle copying from a type to the same type if it's // contiguous (memcpy). We use f16 or f32 shaders to do the copy, // so the type/block size must be a multiple of 4. @@ -11426,11 +12549,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: + return true; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_2D_DW: case GGML_OP_POOL_2D: @@ -11444,14 +12571,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; const vk_device& device = ggml_vk_get_device(ctx->device); - bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE; // Channel-contiguous format is not supported yet. return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && - ggml_is_contiguous(op)) && !is_Apple; + ggml_is_contiguous(op)); } default: return false; @@ -11524,6 +12650,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->device = i; ctx->name = GGML_VK_NAME + std::to_string(i); ctx->description = desc; + ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; + ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, /* .reg = */ reg, @@ -11561,35 +12689,33 @@ ggml_backend_reg_t ggml_backend_vk_reg() { } // Extension availability -static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions) { +static bool ggml_vk_instance_validation_ext_available() { #ifdef GGML_VULKAN_VALIDATE - bool portability_enumeration_ext = false; - // Check for portability enumeration extension for MoltenVK support - for (const auto& properties : instance_extensions) { - if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { - return true; + // Check if validation layer provides the extension + const std::string layer_name = "VK_LAYER_KHRONOS_validation"; + for (const auto& layer : vk::enumerateInstanceLayerProperties()) { + if (layer_name == layer.layerName.data()) { + for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) { + if (strcmp("VK_EXT_validation_features", ext.extensionName.data()) == 0) { + return true; + } + } } } - if (!portability_enumeration_ext) { - std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; - } + + std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_validation_features not found." << std::endl; #endif return false; - - UNUSED(instance_extensions); } static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions) { #ifdef __APPLE__ - bool portability_enumeration_ext = false; // Check for portability enumeration extension for MoltenVK support for (const auto& properties : instance_extensions) { if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { return true; } } - if (!portability_enumeration_ext) { - std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; - } + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; #endif return false; @@ -11612,6 +12738,20 @@ static bool ggml_vk_instance_debug_utils_ext_available( UNUSED(instance_extensions); } +static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) { + VkPhysicalDeviceFeatures2 device_features2; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + + VkPhysicalDeviceVulkan11Features vk11_features; + vk11_features.pNext = nullptr; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + device_features2.pNext = &vk11_features; + + vkGetPhysicalDeviceFeatures2(vkdev, &device_features2); + + return vk11_features.storageBuffer16BitAccess; +} + static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) { switch (props.vendorID) { case VK_VENDOR_ID_INTEL: @@ -11846,7 +12986,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } else if (tensor->op == GGML_OP_CONCAT) { tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_UPSCALE) { - tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); + tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); } else if (tensor->op == GGML_OP_SCALE) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]); @@ -11862,7 +13002,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const float * params = (const float *)tensor->op_params; tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); } else if (tensor->op == GGML_OP_PAD) { - tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]); + tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3], + tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]); } else if (tensor->op == GGML_OP_REPEAT) { tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor); } else if (tensor->op == GGML_OP_REPEAT_BACK) { @@ -11924,6 +13065,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } } else if (tensor->op == GGML_OP_UNARY) { switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_EXP: + tensor_clone = ggml_exp(ggml_ctx, src_clone[0]); + break; case GGML_UNARY_OP_SILU: tensor_clone = ggml_silu(ggml_ctx, src_clone[0]); break; @@ -11945,6 +13089,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_SIGMOID: tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_HARDSIGMOID: + tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_HARDSWISH: + tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]); + break; default: std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); @@ -11983,6 +13133,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SUM_ROWS) { tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_MEAN) { + tensor_clone = ggml_mean(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_ARGMAX) { tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_COUNT_EQUAL) { @@ -11997,6 +13149,19 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const bool is_2D = tensor->op_params[6] == 1; tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); + } else if (tensor->op == GGML_OP_IM2COL_3D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t s2 = tensor->op_params[2]; + const int32_t p0 = tensor->op_params[3]; + const int32_t p1 = tensor->op_params[4]; + const int32_t p2 = tensor->op_params[5]; + const int32_t d0 = tensor->op_params[6]; + const int32_t d1 = tensor->op_params[7]; + const int32_t d2 = tensor->op_params[8]; + const int32_t IC = tensor->op_params[9]; + + tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type); } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { const int32_t dim = tensor->op_params[0]; const int32_t max_period = tensor->op_params[1]; @@ -12080,11 +13245,9 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { return; } - bool fused_rms_norm_mul = false; if (ctx->num_additional_fused_ops == 1 && tensor->op == GGML_OP_RMS_NORM && cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { - fused_rms_norm_mul = true; tensor = cgraph->nodes[tensor_idx + 1]; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp index 2b4085c4f82..00cf2dd62fd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp @@ -1,20 +1,34 @@ #version 450 #extension GL_EXT_shader_16bit_storage : require +#if ADD_RMS +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif #include "types.comp" #include "generic_binary_head.comp" const uint num_threads = 256; +layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];}; + layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; +#if ADD_RMS +// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant +shared FLOAT_TYPE sumsh[num_threads]; +#endif + void main() { uint idx = get_idx(); + uint orig_idx = idx; // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation const uint num_iter = 2; + FLOAT_TYPE sum_sq = 0; + [[unroll]] for (uint i = 0; i < num_iter; ++i) { if (idx >= p.ne) { continue; @@ -22,8 +36,34 @@ void main() { uint i00, i01, i02, i03; get_indices(idx, i00, i01, i02, i03); - data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]); + sum_sq += sum*sum; + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum); idx += num_threads; } + +#if ADD_RMS + if (p.param3 != 0) { + // reduce the sum within each subgroup, then across subgroups + const uint NumSubgroups = num_threads / gl_SubgroupSize; + sum_sq = subgroupAdd(sum_sq); + if (gl_SubgroupInvocationID == 0) { + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) { + if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) { + sum_sq += sumsh[gl_SubgroupID + s]; + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + } + + if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) { + partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq; + } + } +#endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp index 48f6b65bc40..127c7b64240 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp @@ -29,7 +29,7 @@ void main() { uint qs = data_a[ib].qs[4 * ib32 + l]; const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l]; qs |= (qh << (8 - 2 * l)) & 0x300; - const uvec2 grid = iq2s_grid[qs & 511]; + const uvec2 grid = iq2s_grid[qs]; const u8vec4 grid0 = unpack8(grid.x); const u8vec4 grid1 = unpack8(grid.y); data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0)); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp index e370690bcb0..0ae9acd02a6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp @@ -33,7 +33,8 @@ void main() { [[unroll]] for (uint l = 0; l < 4; ++l) { const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit - const uvec2 grid = iq2xxs_grid[data_a[ib].qs[8 * is + l]]; + const uint qs = data_a[ib].qs[8 * is + l]; + const uvec2 grid = iq2xxs_grid[qs]; const u8vec4 grid0 = unpack8(grid.x); const u8vec4 grid1 = unpack8(grid.y); data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp index c3f4bca5d95..e4f42be94c7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp @@ -22,15 +22,16 @@ void main() { const uint b_idx = 256 * ib + 32 * is; const float d = float(data_a[ib].d); - const float db = d * (1 + 2 * ((data_a[ib].scales[is] >> (4 * (is % 2))) & 0xf)); + const float db = d * (1 + 2 * ((data_a[ib].scales[is / 2] >> (4 * (is % 2))) & 0xf)); // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes. uint qh = data_a[ib].qh[is]; [[unroll]] for (uint l = 0; l < 8; ++l) { - uint qs = data_a[ib].qs[8 * is + l]; - uint gidx = qs | ((qh << (8 - l)) & 256); - uint8_t signs = data_a[ib].signs[8 * is + l / 2] >> (4 * (l & 1)); - u8vec4 grid = unpack8(iq3s_grid[gidx]); + const uint iqs = 8 * is + l; + const uint qs = data_a[ib].qs[iqs]; + const uint gidx = qs | ((qh << (8 - l)) & 256); + const uint8_t signs = data_a[ib].signs[iqs / 2] >> (4 * (l & 1)); + const u8vec4 grid = unpack8(iq3s_grid[gidx]); data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0)); data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0)); data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0)); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp index a92b82961af..19c7fdeefce 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp @@ -35,8 +35,10 @@ void main() { const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); // Restore parity bit. const uint sign8 = sign7 | (bitCount(sign7) << 7); - const u8vec4 grid0 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l]]); - const u8vec4 grid1 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l + 1]]); + const uint qs0 = data_a[ib].qs[8 * is + 2 * l]; + const uint qs1 = data_a[ib].qs[8 * is + 2 * l + 1]; + const u8vec4 grid0 = unpack8(iq3xxs_grid[qs0]); + const u8vec4 grid1 = unpack8(iq3xxs_grid[qs1]); data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp new file mode 100644 index 00000000000..abecd2d3dcf --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + data_d[i] = D_TYPE(exp(float(data_a[i]))); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index d40848e15fe..482445c6fea 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -334,6 +334,9 @@ void main() { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { Of[r][d] *= Lfrcp[r]; +#if defined(ACC_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX)); +#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp index b57c9dcfc4e..f73e17e1fa8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -9,6 +9,10 @@ layout (constant_id = 4) const uint32_t HSV = 32; layout (constant_id = 5) const uint32_t Clamp = 0; layout (constant_id = 6) const uint32_t D_split = 16; +// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths +const uint32_t HSK_pad = (HSK + 15) & ~15; +const uint32_t HSV_pad = (HSV + 15) & ~15; + layout (push_constant) uniform parameter { uint32_t N; uint32_t KV; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 81cc3f81fce..63b32171b0c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -46,14 +46,14 @@ const uint32_t MatBc = 16; shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; -const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4 +const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 shared f16vec4 Qf[Br * qstride]; // Avoid padding for hsk==256 to make it fit in 48KB shmem. const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br; shared ACC_TYPE sfsh[Bc * sfshstride]; -const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4 +const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4 shared f16vec4 ksh[Bc * kshstride]; shared float slope[Br]; @@ -74,6 +74,21 @@ void main() { #define tile_row(r) (row_tid * rows_per_thread + (r)) + // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK). + if ((HSK % 16) != 0) { + [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) { + if (i + tid < Br * qstride) { + Qf[i + tid] = f16vec4(0); + } + } + [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) { + if (i + tid < Bc * kshstride) { + ksh[i + tid] = f16vec4(0); + } + } + barrier(); + } + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { @@ -151,14 +166,14 @@ void main() { } barrier(); - // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br + // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 // This is written transposed in order to allow for N being 8 if implementations need it coopmat SfMat = coopmat(0); coopmat KMat; coopmat QMat; - for (uint32_t d = 0; d < HSK / 16; ++d) { + for (uint32_t d = 0; d < HSK_pad / 16; ++d) { coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; @@ -358,6 +373,9 @@ void main() { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] *= ACC_TYPE(Lfrcp[r]); +#if defined(ACC_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX); +#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index b0564ca0bfc..ab647e9bc8b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -104,16 +104,16 @@ void main() { tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); - coopmat Q; - coopmat Qf16; + coopmat Q; + coopmat Qf16; uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; - coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK)); + coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad)); - Qf16 = coopmat(Q); + Qf16 = coopmat(Q); Qf16 *= float16_t(p.scale); - coopmat O = coopmat(0); + coopmat O = coopmat(0); coopmat L, M; @@ -140,10 +140,10 @@ void main() { coopmat S = coopmat(0); - coopmat K_T; + coopmat K_T; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; - coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC); + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC); S = coopMatMulAdd(Qf16, K_T, S); if (p.logit_softcap != 0.0f) { @@ -208,31 +208,31 @@ void main() { rowsum = coopmat(0.0); rowsum = coopMatMulAdd(P_A, One, rowsum); - coopmat V; + coopmat V; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; - coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC); + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC); L = eM*L + rowsum; // This is the "diagonal" matrix in the paper, but since we do componentwise // multiply rather than matrix multiply it has the diagonal element smeared // across the row - coopmat eMdiag; + coopmat eMdiag; // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); // multiply with fp16 accumulation, then add to O. - coopmat PV = coopmat(0); + coopmat PV = coopmat(0); PV = coopMatMulAdd(P_A, V, PV); - O = eMdiag * O + coopmat(PV); + O = eMdiag * O + coopmat(PV); } // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - coopmat O_D = coopmat(O); + coopmat O_D = coopmat(O); uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); @@ -243,16 +243,16 @@ void main() { return; } - coopmat Ldiag; + coopmat Ldiag; // resize L by using smear/reduce coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { - coopmat S; + coopmat S; coopMatPerElementNV(S, S, perElemOpGetSink, iq2); - coopmat Mr; + coopmat Mr; // resize M by using smear/reduce coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce); @@ -283,9 +283,13 @@ void main() { O = Ldiag*O; +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; - coopmat O_D = coopmat(O); + coopmat O_D = coopmat(O); if (p.gqa_ratio > 1) { coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); } else { @@ -295,6 +299,6 @@ void main() { // permute dimensions tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); - coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute); + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index 76ef4b6dfb5..06e83822fe3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -111,6 +111,10 @@ void main() { } } O *= L; + + const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF); + O = clamp(O, -FLT_MAX, FLT_MAX); + data_d[iq3 * D * N + D * n + d] = O; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp index ee6b86a18dd..7ef75cd7a49 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -7,27 +7,36 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; void main() { const uint i00 = gl_GlobalInvocationID.x; - const uint i10 = gl_GlobalInvocationID.y; - const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; - const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; if (i00 >= p.ne00) { return; } - const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + uint gid_z = gl_GlobalInvocationID.z; + while (gid_z < p.ne11 * p.ne12) { + uint gid_y = gl_GlobalInvocationID.y; + while (gid_y < p.ne10) { + const uint i10 = gid_y; + const uint i11 = gid_z / p.ne12; + const uint i12 = gid_z % p.ne12; - const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; - const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + + const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; #if defined(DATA_A_BF16) - FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00])); + FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00])); #else - FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]); + FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]); #endif #ifndef OPTIMIZATION_ERROR_WORKAROUND - data_d[d_offset + i00] = D_TYPE(v); + data_d[d_offset + i00] = D_TYPE(v); #else - data_d[d_offset + i00] = D_TYPE(v); + data_d[d_offset + i00] = D_TYPE(v); #endif + gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z; + } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index cfd645a38a8..339f905fc75 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -10,9 +10,6 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; void main() { const uint i00 = (gl_GlobalInvocationID.x)*2; - const uint i10 = gl_GlobalInvocationID.y; - const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; - const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -22,20 +19,33 @@ void main() { return; } - const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + uint gid_z = gl_GlobalInvocationID.z; + while (gid_z < p.ne11 * p.ne12) { + uint gid_y = gl_GlobalInvocationID.y; + while (gid_y < p.ne10) { + const uint i10 = gid_y; + const uint i11 = gid_z / p.ne12; + const uint i12 = gid_z % p.ne12; - const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; - const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; - const uint ib = a_offset + i00/QUANT_K; // block index - const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index - const uint iybs = i00 - i00%QUANT_K; // dst block start index - const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23; - vec2 v = dequantize(ib, iqs, 0); - const vec2 dm = get_dm(ib, 0); - v = v * dm.x + dm.y; + const uint ib = a_offset + i00/QUANT_K; // block index + const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index + const uint iybs = i00 - i00%QUANT_K; // dst block start index + const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; - data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); - data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); + vec2 v = dequantize(ib, iqs, 0); + const vec2 dm = get_dm(ib, 0); + v = v * dm.x + dm.y; + + data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); + data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); + + gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z; + } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp b/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp new file mode 100644 index 00000000000..1da252cc663 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(min(1.0f, max(0.0f, (x + 3.0f) / 6.0f))); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp b/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp new file mode 100644 index 00000000000..3afc588274f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(x * min(1.0f, max(0.0f, (x + 3.0f) / 6.0f))); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp new file mode 100644 index 00000000000..3b010bdeb5e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp @@ -0,0 +1,112 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "rte.comp" + +layout (push_constant) uniform parameter +{ + uint32_t nb10; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t s0; + uint32_t s1; + uint32_t s2; + uint32_t p0; + uint32_t p1; + uint32_t p2; + uint32_t d0; + uint32_t d1; + uint32_t d2; + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t IC; + uint32_t KW; + uint32_t OH; + uint32_t KD_KH_KW; + uint32_t KH_KW; + uint32_t IC_KD_KH_KW; + uint32_t N_OD_OH; + uint32_t OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW; + uint32_t misalign_offsets; +} p; + +#include "types.comp" + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint32_t i = gl_GlobalInvocationID.x; + + uint32_t nb10 = p.nb10; + uint32_t nb11 = p.nb11; + uint32_t nb12 = p.nb12; + uint32_t nb13 = p.nb13; + uint32_t s0 = p.s0; + uint32_t s1 = p.s1; + uint32_t s2 = p.s2; + uint32_t p0 = p.p0; + uint32_t p1 = p.p1; + uint32_t p2 = p.p2; + uint32_t d0 = p.d0; + uint32_t d1 = p.d1; + uint32_t d2 = p.d2; + uint32_t IW = p.IW; + uint32_t IH = p.IH; + uint32_t ID = p.ID; + uint32_t IC = p.IC; + uint32_t KW = p.KW; + uint32_t OH = p.OH; + uint32_t KD_KH_KW = p.KD_KH_KW; + uint32_t KH_KW = p.KH_KW; + uint32_t IC_KD_KH_KW = p.IC_KD_KH_KW; + uint32_t N_OD_OH = p.N_OD_OH; + uint32_t OD_OH = p.OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW = p.OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW = p.OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW = p.OW_IC_KD_KH_KW; + + if (i >= IC_KD_KH_KW) { + return; + } + + const uint32_t iic = i / KD_KH_KW; + const uint32_t ikd = (i - iic * KD_KH_KW) / KH_KW; + const uint32_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; + const uint32_t ikw = i % KW; + + const uint32_t iow = gl_GlobalInvocationID.y; + for (uint32_t iz = gl_GlobalInvocationID.z; iz < N_OD_OH; iz += gl_NumWorkGroups.z) { + const uint32_t in_ = iz / OD_OH; + const uint32_t iod = (iz - in_*OD_OH) / OH; + const uint32_t ioh = iz % OH; + + const uint32_t iiw = iow * s0 + ikw * d0 - p0; + const uint32_t iih = ioh * s1 + ikh * d1 - p1; + const uint32_t iid = iod * s2 + ikd * d2 - p2; + + const uint32_t offset_dst = in_*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + + if (iih >= IH || iiw >= IW || iid >= ID) { + data_d[offset_dst + get_doffset()] = D_TYPE(0.0f); + } else { + const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10; + data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp index b93e9948f7f..f761391eaed 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp @@ -1,7 +1,8 @@ #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_8bit_storage : require -#if USE_SUBGROUP_ADD + +#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM #extension GL_KHR_shader_subgroup_basic : require #extension GL_KHR_shader_subgroup_arithmetic : require #endif @@ -12,10 +13,19 @@ #include "types.comp" +#ifndef MMQ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#else +layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#endif + layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +#ifdef B_TYPE_VEC2 layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; +#endif +#ifdef B_TYPE_VEC4 layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; +#endif layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #ifdef MUL_MAT_ID @@ -92,6 +102,23 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32; layout (constant_id = 1) const uint NUM_ROWS = 1; layout (constant_id = 2) const uint NUM_COLS = 1; +#ifdef USE_SUBGROUP_ADD_NO_SHMEM +void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = subgroupAdd(temp[j][n]); + } + } + + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]); + } + } + } +} +#else shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE]; void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { @@ -152,3 +179,4 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs } #endif } +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp new file mode 100644 index 00000000000..8fb314fa0aa --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -0,0 +1,140 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_integer_dot_product : require + +#define MMQ +#define B_TYPE block_q8_1_x4 + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +#define K_PER_ITER 8 + +#include "mul_mmq_funcs.comp" + +uint a_offset, b_offset, d_offset; + +int32_t cache_b_qs[2]; +vec2 cache_b_ds; + +void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + tid*K_PER_ITER; + + // Preload data_b block + const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset; + const uint b_qs_idx = tid % 4; + const uint b_block_idx_outer = b_block_idx / 4; + const uint b_block_idx_inner = b_block_idx % 4; + cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]); + +#if QUANT_R == 2 + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4]; +#else + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1]; +#endif + + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint a_block_idx = (ibi + col)/QUANT_K + a_offset; + ibi += p.ncols; + + int32_t q_sum = 0; +#if QUANT_R == 2 + const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx); + q_sum += dotPacked4x8EXT(data_a_qs.x, + cache_b_qs[0]); + q_sum += dotPacked4x8EXT(data_a_qs.y, + cache_b_qs[1]); +#else + int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[0]); + data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[1]); +#endif + +#if QUANT_AUXF == 1 + temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4); +#else + temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4); +#endif + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + const uint tid = gl_LocalInvocationID.x; + + get_offsets(a_offset, b_offset, d_offset); + a_offset /= QUANT_K; + b_offset /= QUANT_K_Q8_1; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = FLOAT_TYPE(0.0f); + } + } + + uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE); + if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) { + num_iters++; + } + int unroll_count = 4; + uint unrolled_iters = num_iters & ~(unroll_count - 1); + + uint i = 0; + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + + unroll_count = 2; + unrolled_iters = num_iters & ~(unroll_count - 1); + +#if K_PER_ITER == 2 + if ((p.ncols & 1) != 0 && + unrolled_iters == num_iters && + unrolled_iters > 0) { + unrolled_iters -= unroll_count; + } +#endif + + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + while (i < num_iters) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index a61a464c7be..38a4d07d030 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -17,6 +17,9 @@ #ifdef COOPMAT #extension GL_KHR_cooperative_matrix : enable #extension GL_KHR_memory_scope_semantics : enable +#endif + +#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS) #extension GL_KHR_shader_subgroup_basic : enable #extension GL_KHR_shader_subgroup_ballot : enable #endif @@ -28,10 +31,10 @@ #include "types.comp" #ifndef LOAD_VEC_A -#define LOAD_VEC_A 1 +#define LOAD_VEC_A 2 #endif #ifndef LOAD_VEC_B -#define LOAD_VEC_B 1 +#define LOAD_VEC_B 2 #endif #if !defined(TO_FLOAT_TYPE) @@ -95,28 +98,93 @@ layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat layout (constant_id = 10) const uint WARP = 32; #ifdef COOPMAT -#define SHMEM_STRIDE (BK + 8) +#define SHMEM_STRIDE (BK / 2 + 4) #else -#define SHMEM_STRIDE (BK + 1) +#define SHMEM_STRIDE (BK / 2 + 1) #endif -shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; -shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; +shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE]; +shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE]; + +#define NUM_WARPS (BLOCK_SIZE / WARP) #ifdef MUL_MAT_ID -shared u16vec2 row_ids[4096]; +shared u16vec2 row_ids[BN]; uint _ne1; -#ifdef COOPMAT -shared uint _ne1_sh; -#endif -#endif // MUL_MAT_ID -#define NUM_WARPS (BLOCK_SIZE / WARP) +#ifdef MUL_MAT_ID_USE_SUBGROUPS +shared uvec4 ballots_sh[NUM_WARPS]; + +void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + uint nei0shift = findLSB(p.nei0); + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_LocalInvocationIndex; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + + ballots_sh[gl_SubgroupID] = ballot; + barrier(); + + uint subgroup_base = 0; + uint total = 0; + for (uint k = 0; k < gl_NumSubgroups; ++k) { + if (k == gl_SubgroupID) { + subgroup_base = total; + } + total += subgroupBallotBitCount(ballots_sh[k]); + } + barrier(); + + uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { + row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1); + } + _ne1 += total; + iter &= 15; + if (_ne1 >= (ic + 1) * BN) { + break; + } + } + barrier(); +} +#endif // MUL_MAT_ID_USE_SUBGROUPS +#endif // MUL_MAT_ID #ifdef COOPMAT shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif +#include "mul_mm_funcs.comp" + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -177,51 +245,20 @@ void main() { const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; #ifdef MUL_MAT_ID -#ifdef COOPMAT - // Spread the search across all elements in the first subgroup - if (gl_SubgroupID == 0) { - _ne1 = 0; - uint num_elements = p.nei1 * p.nei0; - - uint ids[16]; - uint iter = 0; - - for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { - // prefetch up to 16 elements - if (iter == 0) { - [[unroll]] for (uint k = 0; k < 16; ++k) { - uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; - bool in_range = i < num_elements; - uint ii1 = i / p.nei0; - uint ii0 = i % p.nei0; - ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; - } - } - uint i = j + gl_SubgroupInvocationID; - bool in_range = i < num_elements; - uint ii1 = i / p.nei0; - uint ii0 = i % p.nei0; - uint id = ids[iter++]; - uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - uint idx = subgroupBallotExclusiveBitCount(ballot); - if (in_range && id == expert_idx) { - row_ids[_ne1 + idx] = u16vec2(ii0, ii1); - } - _ne1 += subgroupBallotBitCount(ballot); - iter &= 15; - } - _ne1_sh = _ne1; +#ifdef MUL_MAT_ID_USE_SUBGROUPS + if (bitCount(p.nei0) == 1) { + load_row_ids(expert_idx, true, ic); + } else { + load_row_ids(expert_idx, false, ic); } - - barrier(); - - _ne1 = _ne1_sh; #else _ne1 = 0; - for (uint ii1 = 0; ii1 < p.nei1; ii1++) { - for (uint ii0 = 0; ii0 < p.nei0; ii0++) { + for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) { if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { - row_ids[_ne1] = u16vec2(ii0, ii1); + if (_ne1 >= ic * BN) { + row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1); + } _ne1++; } } @@ -265,8 +302,8 @@ void main() { } #else ACC_TYPE sums[WMITER * TM * WNITER * TN]; - FLOAT_TYPE cache_a[WMITER * TM]; - FLOAT_TYPE cache_b[TN]; + FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; + FLOAT_TYPE_VEC2 cache_b[TN]; [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { sums[i] = ACC_TYPE(0.0f); @@ -275,538 +312,13 @@ void main() { for (uint block = start_k; block < end_k; block += BK) { [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { - -#if defined(DATA_A_F32) || defined(DATA_A_F16) -#if LOAD_VEC_A == 8 - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); - buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); - buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); - buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w); - buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x); - buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y); - buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z); - buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); -#elif LOAD_VEC_A == 4 - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); - buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); - buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); - buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); -#else - if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { - buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); - } else { - buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f); - } -#endif -#elif defined(DATA_A_BF16) -#if LOAD_VEC_A == 4 - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x); - buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y); - buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z); - buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w); -#else - if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { - buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); - } else { - buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0)); - } -#endif -#elif defined(DATA_A_Q4_0) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; - - const uint ib = idx / 4; - const uint iqs = idx & 0x03; - - const float d = float(data_a_packed16[ib].d); - const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); - const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; - const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; - - buf_a[buf_idx ] = FLOAT_TYPE(v0.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); - buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); - buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); - buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); - buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); - buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); - buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); -#elif defined(DATA_A_Q4_1) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; - - const uint ib = idx / 4; - const uint iqs = idx & 0x03; - - const float d = float(data_a_packed16[ib].d); - const float m = float(data_a_packed16[ib].m); - const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); - const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; - const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; - - buf_a[buf_idx ] = FLOAT_TYPE(v0.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); - buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); - buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); - buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); - buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); - buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); - buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); -#elif defined(DATA_A_Q5_0) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - - const uint ib = idx / 8; - const uint iqs = idx & 0x07; - - const float d = float(data_a_packed16[ib].d); - const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]); - const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); - const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); - - const uint vui = uint(data_a_packed16[ib].qs[iqs]); - const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); - buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); -#elif defined(DATA_A_Q5_1) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - - const uint ib = idx / 8; - const uint iqs = idx & 0x07; - - const float d = float(data_a_packed16[ib].d); - const float m = float(data_a_packed16[ib].m); - const uint uint_qh = data_a_packed16[ib].qh; - const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); - const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); - - const uint vui = uint(data_a_packed16[ib].qs[iqs]); - const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); - buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); -#elif defined(DATA_A_Q8_0) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 8; - const uint iqs = idx & 0x07; - - const float d = float(data_a_packed16[ib].d); - const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; - const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); - buf_a[buf_idx + 2] = FLOAT_TYPE(v.z); - buf_a[buf_idx + 3] = FLOAT_TYPE(v.w); -#elif defined(DATA_A_Q2_K) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 - const uint scalesi = iqs / 8; // 0..15 - const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 - - const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); - const uint scales = data_a[ib].scales[scalesi]; - const vec2 d = vec2(data_a[ib].d); - - const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); -#elif defined(DATA_A_Q3_K) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint n = iqs / 64; // 0,1 - const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 - const uint hmi = (iqs % 16) * 2; // 0,2,4..30 - const uint j = (iqs % 64) / 4; // 0..3 - const uint is = iqs / 8; // 0..15 - const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 - const uint qsshift = halfsplit * 2; // 0,2,4,6 - const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 - - const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) - | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); - const float dl = float(data_a[ib].d) * float(us - 32); - - buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4))); - buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); -#elif defined(DATA_A_Q4_K) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint n = iqs / 32; // 0,1,2,3 - const uint b = (iqs % 32) / 16; // 0,1 - const uint is = 2 * n + b; // 0..7 - const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 - - const vec2 loadd = vec2(data_a[ib].d); - - const uint scidx0 = (is < 4) ? is : (is + 4); - const uint scidx1 = (is < 4) ? is : (is - 4); - const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint scidxshift1 = (is < 4) ? 0 : 2; - const uint mbidx0 = is + 4; - const uint mbidx1 = (is < 4) ? is + 4 : is; - const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; - const uint mbidxshift0 = (is < 4) ? 0 : 4; - const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint mbidxshift1 = (is < 4) ? 0 : 2; - - const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); - const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); - - const float d = loadd.x * sc; - const float m = -loadd.y * mbyte; - - buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m)); - buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); -#elif defined(DATA_A_Q5_K) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint n = iqs / 32; // 0,1,2,3 - const uint b = (iqs % 32) / 16; // 0,1 - const uint is = 2 * n + b; // 0..7 - const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 - const uint qhi = (iqs % 16) * 2; // 0,2,4..30 - - const uint8_t hm = uint8_t(1 << (iqs / 16)); - - const vec2 loadd = vec2(data_a[ib].d); - - const uint scidx0 = (is < 4) ? is : (is + 4); - const uint scidx1 = (is < 4) ? is : (is - 4); - const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint scidxshift1 = (is < 4) ? 0 : 2; - const uint mbidx0 = is + 4; - const uint mbidx1 = (is < 4) ? is + 4 : is; - const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; - const uint mbidxshift0 = (is < 4) ? 0 : 4; - const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint mbidxshift1 = (is < 4) ? 0 : 2; - - const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); - const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); - - const float d = loadd.x * sc; - const float m = -loadd.y * mbyte; - - buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m)); - buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); -#elif defined(DATA_A_Q6_K) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 - - const uint n = iqs / 64; // 0,1 - const uint b = (iqs % 64) / 32; // 0,1 - const uint is_b = (iqs % 16) / 8; // 0,1 - const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 - const uint is = 8 * n + qhshift + is_b; // 0..15 - const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 - const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 - - const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); - - buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); - buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); -#elif defined(DATA_A_IQ1_S) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 32; // 8 values per idx - const uint ib32 = (idx % 32) / 4; // 0..7 - const uint ib8 = idx % 32; - - const float d = float(data_a[ib].d); - const uint qh = data_a[ib].qh[ib32]; - const uint qs = data_a[ib].qs[ib8]; - const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1); - const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; - const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); - - [[unroll]] for (int k = 0; k < 8; ++k) { - buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); - } -#elif defined(DATA_A_IQ1_M) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 32; // 8 values per idx - const uint ib8 = idx % 32; - const uint ib16 = ib8 / 2; - - const uint16_t[4] scales = data_a[ib].scales; - const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; - const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); - const uint sc = scales[ib8 / 8]; - const uint qs = data_a[ib].qs[ib8]; - const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1)); - const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); - const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; - const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); - - [[unroll]] for (int k = 0; k < 8; ++k) { - buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); - } -#elif defined(DATA_A_IQ2_XXS) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 32; // 8 values per idx - const uint ib32 = (idx % 32) / 4; // 0..7 - const uint ib8 = idx % 4; - - const float d = float(data_a[ib].d); - const uint qs = data_a[ib].qs[8 * ib32 + ib8]; - const uint signs = pack32(u8vec4( - data_a[ib].qs[8*ib32 + 4], - data_a[ib].qs[8*ib32 + 5], - data_a[ib].qs[8*ib32 + 6], - data_a[ib].qs[8*ib32 + 7] - )); - const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28))); - const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); - const uint sign = sign7 | (bitCount(sign7) << 7); - const uvec2 grid = iq2xxs_grid[qs]; - const vec4 grid0 = vec4(unpack8(grid.x)); - const vec4 grid1 = vec4(unpack8(grid.y)); - - buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); - buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); - buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); - buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); - buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); -#elif defined(DATA_A_IQ2_XS) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 32; // 8 values per idx - const uint ib32 = (idx % 32) / 4; // 0..7 - const uint ib8 = idx % 4; // 0..3 - - const float d = float(data_a[ib].d); - const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; - const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); - const uint qs = data_a[ib].qs[4 * ib32 + ib8]; - const uint sign7 = qs >> 9; - const uint sign = sign7 | (bitCount(sign7) << 7); - const uvec2 grid = iq2xs_grid[qs & 511]; - const vec4 grid0 = vec4(unpack8(grid.x)); - const vec4 grid1 = vec4(unpack8(grid.y)); - - buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); - buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); - buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); - buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); - buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); -#elif defined(DATA_A_IQ2_S) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 32; // 8 values per idx - const uint ib8 = idx % 32; // 0..31 - const uint ib32 = ib8 / 4; // 0..7 - - const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; - const uint qs = data_a[ib].qs[ib8]; - const uint qh = data_a[ib].qh[ib32]; - const uint qhshift = 2 * (ib8 % 4); - const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8]; - - const float d = float(data_a[ib].d); - const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); - const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)]; - const vec4 grid0 = vec4(unpack8(grid.x)); - const vec4 grid1 = vec4(unpack8(grid.y)); - - buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); - buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); - buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); - buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); - buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); -#elif defined(DATA_A_IQ3_XXS) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 64; // 4 values per idx - const uint iqs = idx % 64; // 0..63 - const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values - - const float d = float(data_a[ib].d); - const uint qs = data_a[ib].qs[iqs]; - const uint signs = pack32(u8vec4( - data_a[ib].qs[is+0], - data_a[ib].qs[is+1], - data_a[ib].qs[is+2], - data_a[ib].qs[is+3] - )); - const float db = d * 0.5 * (0.5 + (signs >> 28)); - const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2)); - const uint grid = iq3xxs_grid[qs]; - const vec4 v = db * vec4(unpack8(grid)); - - buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); - buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); - buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w); -#elif defined(DATA_A_IQ3_S) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 64; // 4 values per idx - const uint iqs = idx % 64; // 0..63 - const uint iqh = iqs / 8; - - const float d = float(data_a[ib].d); - const uint qs = data_a[ib].qs[iqs]; - const uint qh = data_a[ib].qh[iqh]; - const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2))); - const uint scale = data_a[ib].scales[iqs / 16]; - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); - const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); - const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; - const vec4 v = db * vec4(unpack8(grid)); - - buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); - buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); - buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w); -#elif defined(DATA_A_IQ4_XS) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint iq = 16 * ib32 + 2 * (idx % 8); - - const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; - const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; - const uint qshift = (idx & 8) >> 1; - u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]); - qs = (qs >> qshift) & uint8_t(0xF); - - const float d = float(data_a[ib].d); - const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); -#elif defined(DATA_A_IQ4_NL) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - - const uint ib = idx / 8; - const uint iqs = idx & 0x07; - - const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); - const uint vui = uint(data_a_packed16[ib].qs[iqs]); - - buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d; - buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d; - buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d; - buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d; -#elif defined(DATA_A_MXFP4) - const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - - const uint ib = idx / 8; - const uint iqs = (idx & 0x07) * 2; - - const float d = e8m0_to_fp32(data_a[ib].e); - const uint vui = uint(data_a[ib].qs[iqs]); - const uint vui2 = uint(data_a[ib].qs[iqs+1]); - - buf_a[buf_idx ] = FLOAT_TYPE(kvalues_mxfp4[vui & 0xF] * d); - buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_mxfp4[vui >> 4] * d); - buf_a[buf_idx + 1] = FLOAT_TYPE(kvalues_mxfp4[vui2 & 0xF] * d); - buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_mxfp4[vui2 >> 4] * d); -#endif + load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block, end_k); } [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { -#if LOAD_VEC_B == 8 -#ifdef MUL_MAT_ID - const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; - const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; +#if !defined(MUL_MAT_ID) + load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k); #else - const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; -#endif - const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w); - buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x); - buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y); - buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z); - buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); -#elif LOAD_VEC_B == 4 -#ifdef MUL_MAT_ID - const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; - const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; -#else - const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; -#endif - const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x); - buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y); - buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z); - buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w); -#elif !MUL_MAT_ID - if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); - } else { - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); - } -#else - const uint row_i = ic * BN + loadc_b + l; - if (row_i < _ne1 && block + loadr_b < end_k) { - const u16vec2 row_idx = row_ids[row_i]; - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); - } else { - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); - } + load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block, end_k); #endif } @@ -819,17 +331,17 @@ void main() { [[unroll]] for (uint i = 0; i < BK; i += TK) { [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { // Load from shared into cache - coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]); } } } #else - [[unroll]] for (uint i = 0; i < BK; i++) { + [[unroll]] for (uint i = 0; i < BK / 2; i++) { // Load from shared into cache [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint j = 0; j < TM; j++) { @@ -845,7 +357,7 @@ void main() { [[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) { const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; - sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]); + sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx])); } } } @@ -856,6 +368,20 @@ void main() { barrier(); } +#if defined(ACC_TYPE_MAX) +#ifdef COOPMAT + [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) { + [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) { + sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX); + } + } +#else + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); + } +#endif +#endif + const uint dr = ir * BM + warp_r * WM; const uint dc = ic * BN + warp_c * WN; @@ -873,7 +399,7 @@ void main() { const uint row_i = dc + cm_col * TN + col + store_c; if (row_i >= _ne1) break; - const u16vec2 row_idx = row_ids[row_i]; + const u16vec2 row_idx = row_ids[row_i - ic * BN]; if (dr + cm_row * TM + store_r < p.M) { data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); @@ -923,7 +449,7 @@ void main() { const uint row_i = dc_warp + cc; if (row_i >= _ne1) break; - const u16vec2 row_idx = row_ids[row_i]; + const u16vec2 row_idx = row_ids[row_i - ic * BN]; #endif // MUL_MAT_ID [[unroll]] for (uint cr = 0; cr < TM; cr++) { #ifdef MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 29e4b5c9ce2..69ac38fd419 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -19,6 +19,7 @@ #endif #include "types.comp" +#include "utils.comp" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -92,14 +93,15 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #ifdef MUL_MAT_ID layout (binding = 3) readonly buffer IDS {int data_ids[];}; -shared u16vec4 row_ids[4096]; +shared u16vec4 row_ids[BN]; layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { B_TYPE b[]; }; uint _ne1; -shared uint _ne1_sh; +layout (constant_id = 5) const uint subgroup_size = 32; +shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size]; B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { @@ -109,7 +111,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i return B_TYPE(0.0); } - const u16vec4 row_idx = row_ids[row_i]; + const u16vec4 row_idx = row_ids[row_i & (BN - 1)]; B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; return ret; @@ -121,13 +123,74 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem uint dc = ic * BN + c; if (dr < p.M && dc < _ne1) { - uint row_i = dc; + uint row_i = c; const u16vec4 row_idx = row_ids[row_i]; data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem; } return elem; } +void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + uint nei0shift = findLSB(p.nei0); + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_LocalInvocationIndex; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + + ballots_sh[gl_SubgroupID] = ballot; + barrier(); + + uint subgroup_base = 0; + uint total = 0; + for (uint k = 0; k < gl_NumSubgroups; ++k) { + if (k == gl_SubgroupID) { + subgroup_base = total; + } + total += subgroupBallotBitCount(ballots_sh[k]); + } + barrier(); + + uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { + row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0); + } + _ne1 += total; + iter &= 15; + if (_ne1 >= (ic + 1) * BN) { + break; + } + } + barrier(); +} #endif void main() { @@ -157,45 +220,12 @@ void main() { const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID - // Spread the search across all elements in the first subgroup - if (gl_SubgroupID == 0) { - _ne1 = 0; - uint num_elements = p.nei1 * p.nei0; - - uint ids[16]; - uint iter = 0; - - for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { - // prefetch up to 16 elements - if (iter == 0) { - [[unroll]] for (uint k = 0; k < 16; ++k) { - uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; - bool in_range = i < num_elements; - uint ii1 = i / p.nei0; - uint ii0 = i % p.nei0; - ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; - } - } - uint i = j + gl_SubgroupInvocationID; - bool in_range = i < num_elements; - uint ii1 = i / p.nei0; - uint ii0 = i % p.nei0; - uint id = ids[iter++]; - uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - uint idx = subgroupBallotExclusiveBitCount(ballot); - if (in_range && id == expert_idx) { - row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); - } - _ne1 += subgroupBallotBitCount(ballot); - iter &= 15; - } - _ne1_sh = _ne1; + if (bitCount(p.nei0) == 1) { + load_row_ids(expert_idx, true, ic); + } else { + load_row_ids(expert_idx, false, ic); } - barrier(); - - _ne1 = _ne1_sh; - // Workgroup has no work if (ic * BN >= _ne1) return; #endif @@ -319,6 +349,10 @@ void main() { sum = coopMatMulAdd(mat_a, mat_b, sum); block_k += BK; } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + coopmat mat_d = coopmat(sum); coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose); @@ -358,6 +392,10 @@ void main() { sum = coopMatMulAdd(mat_a, mat_b, sum); block_k += BK; } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + coopmat mat_d = coopmat(sum); coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose); @@ -398,6 +436,10 @@ void main() { sum = coopMatMulAdd(mat_a, mat_b, sum); block_k += BK; } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + coopmat mat_d = coopmat(sum); coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); @@ -414,18 +456,111 @@ void main() { tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); - coopmat sum; - sum = coopmat(0.0); - uint k_iters = (end_k - start_k + BK - 1) / BK; fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false); + store_scales(tid); + +#ifdef MUL_MAT_ID + if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) { + coopmat sum; + sum = coopmat(0.0); + + [[dont_unroll]] + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + if ((block_k % QUANT_K) == 0) { + store_scales(tid); + } + if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); + } + + if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); + + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); + return; + } + if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) { + coopmat sum; + sum = coopmat(0.0); + + [[dont_unroll]] + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + if ((block_k % QUANT_K) == 0) { + store_scales(tid); + } + if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); + } + + if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); + + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); + return; + } +#endif + coopmat sum; + sum = coopmat(0.0); [[dont_unroll]] for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { - store_scales(tid); - if (block_k + BK < end_k) { + if ((block_k % QUANT_K) == 0) { + store_scales(tid); + } + if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) { fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); } @@ -455,6 +590,9 @@ void main() { sum = coopMatMulAdd(mat_a, mat_b, sum); } } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif // Convert from ACC_TYPE to D_TYPE coopmat mat_d; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp new file mode 100644 index 00000000000..69d0e64c35d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp @@ -0,0 +1,556 @@ +void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) { +#if defined(DATA_A_F32) || defined(DATA_A_F16) +#if LOAD_VEC_A == 8 + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]); + buf_a[buf_idx ] = aa[0].xy; + buf_a[buf_idx + 1] = aa[0].zw; + buf_a[buf_idx + 2] = aa[1].xy; + buf_a[buf_idx + 3] = aa[1].zw; +#elif LOAD_VEC_A == 4 + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]); + buf_a[buf_idx ] = aa.xy; + buf_a[buf_idx + 1] = aa.zw; +#else // LOAD_VEC_A == 2 + const uint idx = pos_a * 2 + col * p.stride_a + row * 2; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (idx_m < p.M && block + row * 2 + 1 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], + data_a[idx + 1]); + } else if (idx_m < p.M && block + row * 2 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f); + } else { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +#elif defined(DATA_A_BF16) +#if LOAD_VEC_A == 4 + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx])); + buf_a[buf_idx ] = aa.xy; + buf_a[buf_idx + 1] = aa.zw; +#else // LOAD_VEC_A == 2 + const uint idx = pos_a * 2 + col * p.stride_a + row * 2; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (idx_m < p.M && block + row * 2 + 1 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), + TO_FLOAT_TYPE(data_a[idx + 1])); + } else if (idx_m < p.M && block + row * 2 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); + } else { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +#elif defined(DATA_A_Q4_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + + const uint ib = idx / 4; + const uint iqs = idx & 0x03; + + const float d = float(data_a_packed16[ib].d); + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; + const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy); + buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw); +#elif defined(DATA_A_Q4_1) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + + const uint ib = idx / 4; + const uint iqs = idx & 0x03; + + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; + const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); + buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw); + buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy); + buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw); +#elif defined(DATA_A_Q5_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]); + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); +#elif defined(DATA_A_Q5_1) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint uint_qh = data_a_packed16[ib].qh; + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); +#elif defined(DATA_A_Q8_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; + const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); +#elif defined(DATA_A_Q2_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 + const uint scalesi = iqs / 8; // 0..15 + const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + + const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); + const uint scales = data_a[ib].scales[scalesi]; + const vec2 d = vec2(data_a[ib].d); + + const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); +#elif defined(DATA_A_Q3_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 64; // 0,1 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + const uint hmi = (iqs % 16) * 2; // 0,2,4..30 + const uint j = (iqs % 64) / 4; // 0..3 + const uint is = iqs / 8; // 0..15 + const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) + | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); + const float dl = float(data_a[ib].d) * float(us - 32); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)), + dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); +#elif defined(DATA_A_Q4_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + + const vec2 loadd = vec2(data_a[ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m), + fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); +#elif defined(DATA_A_Q5_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + const uint qhi = (iqs % 16) * 2; // 0,2,4..30 + + const uint8_t hm = uint8_t(1 << (iqs / 16)); + + const vec2 loadd = vec2(data_a[ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m), + fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); +#elif defined(DATA_A_Q6_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 64; // 0,1 + const uint b = (iqs % 64) / 32; // 0,1 + const uint is_b = (iqs % 16) / 8; // 0,1 + const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + const uint is = 8 * n + qhshift + is_b; // 0..15 + const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32), + dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +#elif defined(DATA_A_IQ1_S) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 32; + + const float d = float(data_a[ib].d); + const uint qh = data_a[ib].qh[ib32]; + const uint qs = data_a[ib].qs[ib8]; + const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); + + [[unroll]] for (int k = 0; k < 4; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + } +#elif defined(DATA_A_IQ1_M) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; + const uint ib16 = ib8 / 2; + + const uint16_t[4] scales = data_a[ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + const uint sc = scales[ib8 / 8]; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + + [[unroll]] for (int k = 0; k < 4; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + } +#elif defined(DATA_A_IQ2_XXS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[8 * ib32 + ib8]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[8*ib32 + 4], + data_a[ib].qs[8*ib32 + 5], + data_a[ib].qs[8*ib32 + 6], + data_a[ib].qs[8*ib32 + 7] + )); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28))); + const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xxs_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); +#elif defined(DATA_A_IQ2_XS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; // 0..3 + + const float d = float(data_a[ib].d); + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); + const uint qs = data_a[ib].qs[4 * ib32 + ib8]; + const uint sign7 = qs >> 9; + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xs_grid[qs & 511]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); +#elif defined(DATA_A_IQ2_S) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 + + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8]; + + const float d = float(data_a[ib].d); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); + const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); +#elif defined(DATA_A_IQ3_XXS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 + const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[is+0], + data_a[ib].qs[is+1], + data_a[ib].qs[is+2], + data_a[ib].qs[is+3] + )); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2)); + const uint grid = iq3xxs_grid[qs]; + const vec4 v = db * vec4(unpack8(grid)); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? -v.w : v.w); +#elif defined(DATA_A_IQ3_S) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 + const uint iqh = iqs / 8; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint qh = data_a[ib].qh[iqh]; + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2))); + const uint scale = data_a[ib].scales[iqs / 16]; + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const vec4 v = db * vec4(unpack8(grid)); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? -v.w : v.w); +#elif defined(DATA_A_IQ4_XS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint iq = 16 * ib32 + 2 * (idx % 8); + + const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; + const uint qshift = (idx & 8) >> 1; + u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]); + qs = (qs >> qshift) & uint8_t(0xF); + + const float d = float(data_a[ib].d); + const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); +#elif defined(DATA_A_IQ4_NL) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + + buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF], + kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]); + buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)], + kvalues_iq4nl[vui >> 12]); +#elif defined(DATA_A_MXFP4) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = (idx & 0x07) * 2; + + const float d = e8m0_to_fp32(data_a[ib].e); + const uint vui = uint(data_a[ib].qs[iqs]); + const uint vui2 = uint(data_a[ib].qs[iqs+1]); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d, + kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d, + kvalues_mxfp4[vui2 >> 4] * d); +#endif +} + +#if !defined(MUL_MAT_ID) +void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) { +#if LOAD_VEC_B == 8 + // Not supported for b_type bf16 because bf16mat2x4 does not exist + const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + buf_b[buf_idx + 0] = bb[0].xy; + buf_b[buf_idx + 1] = bb[0].zw; + buf_b[buf_idx + 2] = bb[1].xy; + buf_b[buf_idx + 3] = bb[1].zw; +#elif LOAD_VEC_B == 4 + const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; +#if defined(DATA_B_BF16) + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); +#else + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); +#endif + buf_b[buf_idx + 0] = bb.xy; + buf_b[buf_idx + 1] = bb.zw; +#else // LOAD_VEC_B == 2 + const uint idx = pos_b * 2 + col * p.stride_b + row * 2; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (idx_n < p.N && block + row * 2 + 1 < end_k) { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); + } else if (idx_n < p.N && block + row * 2 < end_k) { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + } else { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +} +#else +void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) { +#if LOAD_VEC_B == 8 + // Not supported for b_type bf16 because bf16mat2x4 does not exist + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + buf_b[buf_idx + 0] = bb[0].xy; + buf_b[buf_idx + 1] = bb[0].zw; + buf_b[buf_idx + 2] = bb[1].xy; + buf_b[buf_idx + 3] = bb[1].zw; +#elif LOAD_VEC_B == 4 + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; +#if defined(DATA_B_BF16) + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); +#else + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); +#endif + buf_b[buf_idx + 0] = bb.xy; + buf_b[buf_idx + 1] = bb.zw; +#else // LOAD_VEC_B == 2 + const uint row_i = ic * BN + col; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (row_i < _ne1 && block + row * 2 + 1 < end_k) { + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); + } else if (row_i < _ne1 && block + row * 2 < end_k) { + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + } else { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +} +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index 83de90eb7e0..f36add62a9e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -28,7 +28,7 @@ layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; #if defined(A_TYPE_PACKED32) layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; #endif -layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];}; +layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #ifdef MUL_MAT_ID @@ -98,7 +98,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN]; #endif #define LOAD_VEC_A (4 * QUANT_R) -#define LOAD_VEC_B 4 +#define LOAD_VEC_B 16 #ifdef MUL_MAT_ID shared u16vec2 row_ids[4096]; @@ -270,15 +270,22 @@ void main() { const uint iqs = idx & 0x7; #else const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK; + const uint ib_outer = ib / 4; + const uint ib_inner = ib % 4; + const uint iqs = loadr_b; #endif const uint buf_ib = loadc_b + l; if (iqs == 0) { - buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds); + buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); } - buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs]; + const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w; } barrier(); @@ -349,7 +356,7 @@ void main() { cache_b_qs[cc * (BK / 4) + idx_k]); } - sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]); + sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp index 34e8db97704..cdfb230f4e7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp @@ -16,8 +16,8 @@ i32vec2 repack(uint ib, uint iqs) { (vui >> 4) & 0x0F0F0F0F); } -ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { - return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y)); +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); } #endif @@ -29,8 +29,8 @@ i32vec2 repack(uint ib, uint iqs) { (vui >> 4) & 0x0F0F0F0F); } -ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { - return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y); +ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); } #endif @@ -50,8 +50,8 @@ i32vec2 repack(uint ib, uint iqs) { return i32vec2(v0, v1); } -ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { - return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y)); +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); } #endif @@ -69,8 +69,8 @@ i32vec2 repack(uint ib, uint iqs) { return i32vec2(v0, v1); } -ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { - return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y); +ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); } #endif @@ -81,7 +81,7 @@ int32_t repack(uint ib, uint iqs) { data_a[ib].qs[iqs * 2 + 1])); } -ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(float(q_sum) * da * dsb.x); } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp index 0c7acb7060f..854a2ad8187 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp @@ -3,6 +3,10 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_nonuniform_qualifier : enable #extension GL_EXT_control_flow_attributes : require +#if ADD_RMS +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif #include "rte.comp" #include "types.comp" @@ -14,11 +18,18 @@ layout (push_constant) uniform parameter2 uint ne20; uint ne21; uint ne22; uint ne23; // strides for srcs+dst - uint nb[8][4]; + uint nb[12][4]; + + uint rms_partials; } p; -layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[]; -layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[]; +// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498 +// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[]; +// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[]; +layout (binding = 0) buffer A {A_TYPE data_a[];} a[]; +layout (binding = 0) buffer D {D_TYPE data_d[];} d[]; + +layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[]; layout(constant_id = 0) const uint num_srcs = 2; @@ -42,14 +53,22 @@ const uint num_threads = 256; layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; +#if ADD_RMS +// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant +shared FLOAT_TYPE sumsh[num_threads]; +#endif + void main() { uint idx = get_idx(); + uint orig_idx = idx; uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23; // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation const uint num_iter = 2; + FLOAT_TYPE sum_sq = 0; + [[unroll]] for (uint i = 0; i < num_iter; ++i) { if (idx >= ne) { continue; @@ -61,8 +80,32 @@ void main() { [[unroll]] for (uint s = 0; s < num_srcs; ++s) { sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]); } + sum_sq += sum*sum; d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum); idx += num_threads; } + +#if ADD_RMS + if (p.rms_partials != 0) { + // reduce the sum within each subgroup, then across subgroups + const uint NumSubgroups = num_threads / gl_SubgroupSize; + sum_sq = subgroupAdd(sum_sq); + if (gl_SubgroupInvocationID == 0) { + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) { + if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) { + sum_sq += sumsh[gl_SubgroupID + s]; + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + } + + if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) { + partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq; + } + } +#endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp index 450b67fc55d..0d81220c71c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -1,7 +1,25 @@ #version 450 #include "types.comp" -#include "generic_unary_head.comp" + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint misalign_offsets; + + uint lp0; uint rp0; + uint lp1; uint rp1; + uint lp2; uint rp2; + uint lp3; uint rp3; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; @@ -19,10 +37,13 @@ void main() { const uint i1 = (idx - i3_offset - i2_offset) / p.ne10; const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; - const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; + const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00; const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; - const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; + const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 && + i1 >= p.lp1 && i1 < p.ne11 - p.rp1 && + i2 >= p.lp2 && i2 < p.ne12 - p.rp2 && + i3 >= p.lp3 && i3 < p.ne13 - p.rp3; data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp index e2e020fec2c..145c9fbdc9f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp @@ -3,6 +3,15 @@ #extension GL_EXT_control_flow_attributes : require #extension GL_EXT_shader_16bit_storage : require +#ifdef USE_SUBGROUPS +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_clustered : require + +#define INVOCATION_ID gl_SubgroupInvocationID.x +#else +#define INVOCATION_ID gl_LocalInvocationID.x +#endif + layout (push_constant) uniform parameter { uint ne; @@ -14,13 +23,19 @@ layout(constant_id = 0) const uint GROUP_SIZE = 32; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {vec4 data_a[];}; +#ifndef QBLOCK_X4 layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];}; +#else +layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];}; +#endif +#ifndef USE_SUBGROUPS shared float shmem[GROUP_SIZE]; +#endif void quantize() { const uint wgid = gl_WorkGroupID.x; - const uint tid = gl_LocalInvocationID.x; + const uint tid = INVOCATION_ID; // Each thread handles a vec4, so 8 threads handle a block const uint blocks_per_group = GROUP_SIZE / 8; @@ -30,9 +45,19 @@ void quantize() { const uint ib = wgid * blocks_per_group + block_in_wg; const uint iqs = tid % 8; +#ifndef QBLOCK_X4 if (ib >= gl_NumWorkGroups.x * blocks_per_group) { return; } +#else + const uint ibx4_outer = ib / 4; + const uint ibx4_inner = ib % 4; + + const uint required_x4_blocks = (p.ne + 127) / 128; + if (ibx4_outer >= required_x4_blocks) { + return; + } +#endif const uint a_idx = ib * 8 + iqs; @@ -40,7 +65,9 @@ void quantize() { const vec4 abs_vals = abs(vals); // Find absolute max for each block - shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); + const float thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); +#ifndef USE_SUBGROUPS + shmem[tid] = thread_max; barrier(); [[unroll]] for (uint s = 4; s > 0; s >>= 1) { if (iqs < s) { @@ -50,14 +77,28 @@ void quantize() { } const float amax = shmem[block_in_wg * 8]; +#else + const float amax = subgroupClusteredMax(thread_max, 8); +#endif + const float d = amax / 127.0; const float d_inv = d != 0.0 ? 1.0 / d : 0.0; vals = round(vals * d_inv); + +#ifndef QBLOCK_X4 data_b[ib].qs[iqs] = pack32(i8vec4(round(vals))); +#else + data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] = pack32(i8vec4(round(vals))); +#endif + +#ifndef USE_SUBGROUPS barrier(); +#endif // Calculate the sum for each block - shmem[tid] = vals.x + vals.y + vals.z + vals.w; + const float thread_sum = vals.x + vals.y + vals.z + vals.w; +#ifndef USE_SUBGROUPS + shmem[tid] = thread_sum; barrier(); [[unroll]] for (uint s = 4; s > 0; s >>= 1) { if (iqs < s) { @@ -65,10 +106,19 @@ void quantize() { } barrier(); } +#else + const float sum = subgroupClusteredAdd(thread_sum, 8); +#endif if (iqs == 0) { +#ifndef USE_SUBGROUPS const float sum = shmem[tid]; +#endif +#ifndef QBLOCK_X4 data_b[ib].ds = f16vec2(vec2(d, sum * d)); +#else + data_b[ibx4_outer].ds[ibx4_inner] = f16vec2(vec2(d, sum * d)); +#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index bdd7db2d698..41197e9301a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -10,9 +10,9 @@ layout (constant_id = 1) const bool do_multiply = false; layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE sum[BLOCK_SIZE]; +shared FLOAT_TYPE sumsh[BLOCK_SIZE]; -void main() { +void rms_norm(uint num_iters) { const uint ncols = p.ne00; const uint nrows = gl_NumWorkGroups.x; const uint nchannels = gl_NumWorkGroups.y; @@ -30,38 +30,76 @@ void main() { uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); - sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]); - sum[tid] += xi * xi; + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + FLOAT_TYPE xi = FLOAT_TYPE(0); + if (col < ncols) { + xi = FLOAT_TYPE(data_a[a_offset + col]); + } + sum += xi * xi; } + sumsh[tid] = sum; // sum up partial sums and write back result barrier(); [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { if (tid < s) { - sum[tid] += sum[tid + s]; + sum += sumsh[tid + s]; + sumsh[tid] = sum; } barrier(); } + sum = sumsh[0]; - const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols); + const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols); const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); if (do_multiply) { if (ncols > p.ne10) { - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); } } else { - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); } } } else { - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); } } } + +void main() { + // instantiate the rms_norm function for several different + // dimensions, to allow loop unrolling + uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE; + if (num_blocks > 32) { + rms_norm(num_blocks); + } else if (num_blocks > 16) { + rms_norm(32); + } else if (num_blocks > 8) { + rms_norm(16); + } else if (num_blocks > 4) { + rms_norm(8); + } else if (num_blocks == 4) { + rms_norm(4); + } else if (num_blocks == 3) { + rms_norm(3); + } else if (num_blocks == 2) { + rms_norm(2); + } else if (num_blocks == 1) { + rms_norm(1); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp new file mode 100644 index 00000000000..ba4677c2933 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp @@ -0,0 +1,65 @@ +#version 450 + +#include "generic_binary_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +#define BLOCK_SIZE 128 + +layout (constant_id = 1) const bool do_multiply = false; + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];}; + +shared FLOAT_TYPE sumsh[BLOCK_SIZE]; + +void main() { + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = 0; + const uint channel = gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + // The work is split across multiple workgroups in the x dimension. Each invocation + // processes one element + const uint tid = gl_GlobalInvocationID.x; + + const uint stride_row = p.nb01; + const uint stride_channel = p.nb02; + const uint stride_sample = p.nb03; + + uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + uint32_t num_partials = p.param3; + for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) { + sum += partial_sums[i]; + } + sum = subgroupAdd(sum); + + uint col = tid; + if (col >= ncols) { + return; + } + + const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols); + const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); + + if (do_multiply) { + if (ncols > p.ne10) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); + } else { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } + } else { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp index 29bd77d7e1c..144ea58e6fa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp @@ -20,6 +20,10 @@ void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + if (row >= p.KY) { + return; + } + FLOAT_TYPE scale = p.param1; // partial sums for thread in warp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp index 961e5ffa1f5..759204afaf9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp @@ -1,9 +1,9 @@ #version 450 -#include "generic_head.comp" #include "types.comp" #extension GL_EXT_control_flow_attributes : enable + layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; @@ -11,16 +11,49 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (constant_id = 0) const uint BLOCK_SIZE = 32; +layout (push_constant) uniform parameter +{ + uint n_cols; + uint ne01, ne02; + uint nb01, nb02, nb03; + uint nb11, nb12, nb13; + float weight; + uint misalign_offsets; + uint ne0_12mp, ne0_12L; + uint ne0_1mp, ne0_1L; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + + shared FLOAT_TYPE tmp[BLOCK_SIZE]; void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint col = gl_LocalInvocationID.x; + const float weight = p.weight; + + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; + + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; - tmp[col] = FLOAT_TYPE(0.0f); + tmp[col] = FLOAT_TYPE(0.0); - for (uint i = col; i < p.KX; i += BLOCK_SIZE) { - tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]); + for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) { + tmp[col] += FLOAT_TYPE(data_a[src_idx + i]); } barrier(); @@ -32,6 +65,6 @@ void main() { } if (col == 0) { - data_d[row] = D_TYPE(tmp[0]); + data_d[dst_idx] = D_TYPE(tmp[0] * weight); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp b/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp index 79e065a9313..ce8e09442d9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp @@ -24,11 +24,12 @@ void main() { const uint j = gl_GlobalInvocationID.x; const uint d_offset = i * p.nb1; - if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) { - data_d[d_offset + p.dim] = 0.f; + const uint half_dim = p.dim / 2; + + if (p.dim % 2 != 0 && j == half_dim) { + data_d[d_offset + 2 * half_dim] = 0.f; } - const uint half_dim = p.dim / 2; if (j >= half_dim) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index a36c33e2676..75aa22eae40 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -11,12 +11,12 @@ #define QUANT_K 1 #define QUANT_R 1 -#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 -#define A_TYPE float -#elif LOAD_VEC_A == 4 +#if LOAD_VEC_A == 4 #define A_TYPE vec4 #elif LOAD_VEC_A == 8 #define A_TYPE mat2x4 +#else +#define A_TYPE float #endif #endif @@ -24,12 +24,12 @@ #define QUANT_K 1 #define QUANT_R 1 -#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 -#define A_TYPE float16_t -#elif LOAD_VEC_A == 4 +#if LOAD_VEC_A == 4 #define A_TYPE f16vec4 #elif LOAD_VEC_A == 8 #define A_TYPE f16mat2x4 +#else +#define A_TYPE float16_t #endif #endif @@ -37,12 +37,12 @@ #define QUANT_K 1 #define QUANT_R 1 -#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 -#define A_TYPE uint16_t -#elif LOAD_VEC_A == 4 +#if LOAD_VEC_A == 4 #define A_TYPE u16vec4 #elif LOAD_VEC_A == 8 #error unsupported +#else +#define A_TYPE uint16_t #endif #endif @@ -207,6 +207,18 @@ struct block_q8_1_packed32 int32_t qs[8]; }; +// 4 blocks in one to allow 16-byte/128-bit alignment and loads +struct block_q8_1_x4 +{ + f16vec2 ds[4]; + int32_t qs[32]; +}; +struct block_q8_1_x4_packed128 +{ + f16vec2 ds[4]; + ivec4 qs[8]; +}; + // K-quants #define QUANT_K_Q2_K 256 @@ -1412,6 +1424,11 @@ float bf16_to_fp32(uint32_t u) return uintBitsToFloat(u << 16); } +vec4 bf16_to_fp32(uvec4 u) +{ + return vec4(bf16_to_fp32(u.x), bf16_to_fp32(u.y), bf16_to_fp32(u.z), bf16_to_fp32(u.w)); +} + float e8m0_to_fp32(uint8_t x) { uint32_t bits; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 75c572d6fc9..74a4794d34f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -68,6 +68,12 @@ const std::vector type_names = { "bf16", }; +enum MatMulIdType { + NONE, + DEFAULT, + SUBGROUP, +}; + namespace { void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { #ifdef _WIN32 @@ -200,6 +206,22 @@ bool string_ends_with(const std::string& str, const std::string& suffix) { return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); } +bool is_quantized_type(const std::string& type_name) { + return type_name != "f32" && type_name != "f16" && type_name != "bf16"; +} + +bool is_legacy_quant(const std::string& type_name) { + return type_name == "q4_0" || type_name == "q4_1" || type_name == "q5_0" || type_name == "q5_1" || type_name == "q8_0"; +} + +bool is_k_quant(const std::string& type_name) { + return string_ends_with(type_name, "_k"); +} + +bool is_iq_quant(const std::string& type_name) { + return string_starts_with(type_name, "iq"); +} + static const char path_separator = '/'; std::string join_paths(const std::string& path1, const std::string& path2) { @@ -293,26 +315,32 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc)); } -void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) { +void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) { std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; - std::map base_dict = { - {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"}, - }; + std::map base_dict; std::string shader_name = "matmul"; - if (matmul_id) { + if (matmul_id_type == MatMulIdType::DEFAULT) { base_dict["MUL_MAT_ID"] = "1"; shader_name = "matmul_id"; + } else if (matmul_id_type == MatMulIdType::SUBGROUP) { + base_dict["MUL_MAT_ID"] = "1"; + base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1"; + shader_name = "matmul_id_subgroup"; } if (fp16) { base_dict["FLOAT16"] = "1"; } - base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float"; + base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2"; + if (f16acc) { + base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\""; + } if (coopmat) { base_dict["COOPMAT"] = "1"; @@ -320,43 +348,96 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; - auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string { - if (t == "bf16") { - // scalar path promotes to float - if (!coopmat && !coopmat2) { - return "float"; + auto const &FLOAT_TYPE = [&](int vec, const std::string &t) -> std::string { + switch (vec) { + case 1: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "float"; + } + return "bfloat16_t"; } - return "bfloat16_t"; - } - if (coopmat2 || fp16) { - return "float16_t"; + if (coopmat2 || fp16) { + return "float16_t"; + } + return "float"; + case 2: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "vec2"; + } + return "bf16vec2"; + } + if (coopmat2 || fp16) { + return "f16vec2"; + } + return "vec2"; + case 4: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "vec4"; + } + return "bf16vec4"; + } + if (coopmat2 || fp16) { + return "f16vec4"; + } + return "vec4"; + case 8: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "mat2x4"; + } + throw std::runtime_error("bf16 vec8 not supported"); + } + if (coopmat2 || fp16) { + return "f16mat2x4"; + } + return "mat2x4"; + default: + throw std::runtime_error("invalid vector size"); } - return "float"; + }; + + const std::map float_type_dict_f16 = { + {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")}, + {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")}, + {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")}, + {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")}, }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); // bf16 { - std::string load_vec_a_unaligned = "1"; // For aligned matmul loads std::string load_vec_a = coopmat2 ? "1" : "4"; // scalar path promotes to float std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32"; + const std::map float_type_dict_bf16 = { + {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")}, + {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")}, + {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")}, + }; + // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader #if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) if (!(coopmat || coopmat2)) #endif { - string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } } @@ -373,24 +454,31 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool std::string data_a_key = "DATA_A_" + to_uppercase(tname); // For unaligned, load one at a time for f32/f16, or two at a time for quants - std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant; + std::string load_vec_a_unaligned = coopmat2 ? "1" : (tname == "f32" || tname == "f16" || tname == "bf16") ? "2" : load_vec_quant; // For aligned matmul loads std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; + const std::map float_type_dict = { + {"FLOAT_TYPE", FLOAT_TYPE(1, tname)}, + {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)}, + {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)}, + {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)}, + }; + // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) { - string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); + if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) { + string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); } #endif } @@ -401,32 +489,38 @@ void process_shaders() { std::map base_dict = {{"FLOAT_TYPE", "float"}}; // matmul - for (const auto& matmul_id : {false, true}) { + for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) { // No coopmats // fp32 - matmul_shaders(false, matmul_id, false, false, false); + matmul_shaders(false, matmul_id_type, false, false, false); // fp16, fp32acc and fp16acc - matmul_shaders(true, matmul_id, false, false, false); - matmul_shaders(true, matmul_id, false, false, true); + matmul_shaders(true, matmul_id_type, false, false, false); + matmul_shaders(true, matmul_id_type, false, false, true); + if (matmul_id_type != MatMulIdType::DEFAULT) { #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - // Coopmat, fp32acc and fp16acc - matmul_shaders(true, matmul_id, true, false, false); - matmul_shaders(true, matmul_id, true, false, true); + // Coopmat, fp32acc and fp16acc + matmul_shaders(true, matmul_id_type, true, false, false); + matmul_shaders(true, matmul_id_type, true, false, true); #endif #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - // Coopmat2, fp32acc and fp16acc - matmul_shaders(true, matmul_id, false, true, false); - matmul_shaders(true, matmul_id, false, true, true); + // Coopmat2, fp32acc and fp16acc + matmul_shaders(true, matmul_id_type, false, true, false); + matmul_shaders(true, matmul_id_type, false, true, true); #endif + } } // flash attention for (const auto& f16acc : {false, true}) { - std::string acctype = f16acc ? "float16_t" : "float"; - std::string acctypev4 = f16acc ? "f16vec4" : "vec4"; + std::map fa_base_dict = base_dict; + fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4"; + if (f16acc) { + fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\""; + } for (const auto& tname : type_names) { if (tname == "f32") { @@ -437,30 +531,30 @@ void process_shaders() { #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc); } else { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); } #endif #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc); + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); } else if (tname == "q4_0" || tname == "q8_0") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); } #endif if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc); + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); } else if (tname == "q4_0" || tname == "q8_0") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); } } } @@ -476,8 +570,20 @@ void process_shaders() { string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + // mul mat vec with integer dot product +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (is_legacy_quant(tname)) { + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + } +#endif + // Dequant shaders if (tname != "f16" && tname != "bf16") { string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); @@ -503,6 +609,7 @@ void process_shaders() { string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -512,10 +619,14 @@ void process_shaders() { string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); + string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); + string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -538,13 +649,15 @@ void process_shaders() { s += std::string(dst_f16 ? "_f16" : "_f32"); return s; }; - for (std::string op : {"add", "sub", "mul", "div"}) { + for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) { for (auto src0_f16 : {false, true}) { for (auto src1_f16 : {false, true}) { for (auto dst_f16 : {false, true}) { for (auto rte : {false, true}) { + auto source = op == "add_rms" ? std::string("add") : op; auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : ""); - string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + auto add_rms = op == "add_rms" ? "1" : "0"; + string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}}); } } } @@ -557,7 +670,12 @@ void process_shaders() { string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {}); + string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); + string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}}); + + string_to_spv("quantize_q8_1_x4", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}}); + string_to_spv("quantize_q8_1_x4_subgroup", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}, {"USE_SUBGROUPS", "1"}}); string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -586,6 +704,8 @@ void process_shaders() { string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("exp_f16", "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("exp_f32", "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); @@ -600,6 +720,10 @@ void process_shaders() { string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("hardsigmoid_f16","hardsigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); for (auto rte : {false, true}) { std::string suffix = rte ? "_rte" : ""; @@ -652,6 +776,10 @@ void process_shaders() { string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); + string_to_spv("im2col_3d_f32", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("im2col_3d_f32_f16", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); + string_to_spv("im2col_3d_f32_f16_rte", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); + string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -678,12 +806,15 @@ void process_shaders() { string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); + string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); + string_to_spv("conv2d_dw_cwhn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}}); + string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}}); for (auto &c : compiles) { c.wait(); @@ -741,7 +872,7 @@ void write_output_files() { } std::string suffixes[2] = {"_f32", "_f16"}; - for (const char *op : {"add", "sub", "mul", "div"}) { + for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) { fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op); fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op); std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = "; @@ -794,12 +925,21 @@ void write_output_files() { fputs(len.c_str(), src); } - for (const std::string& btype : {"f16", "f32"}) { + std::vector btypes = {"f16", "f32"}; + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + btypes.push_back("q8_1"); +#endif + + for (const std::string& btype : btypes) { for (const auto& tname : type_names) { - fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[2];\n", tname.c_str(), btype.c_str()); - fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[2];\n", tname.c_str(), btype.c_str()); - std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[2] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data};\n"; - std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[2] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len};\n"; + if (btype == "q8_1" && !is_legacy_quant(tname)) { + continue; + } + fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[3];\n", tname.c_str(), btype.c_str()); + fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[3];\n", tname.c_str(), btype.c_str()); + std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_data};\n"; + std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_len};\n"; fputs(data.c_str(), src); fputs(len.c_str(), src); } diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index 79ef68b85a4..78a985a4d16 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -20,8 +20,8 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR} COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8 ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py - --input "${SHADER_DIR}" - --output "${SHADER_HEADER}" + --input_dir "${SHADER_DIR}" + --output_file "${SHADER_HEADER}" DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py VERBATIM ) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index ba1addc8d9f..a92ddc582a3 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -116,17 +116,27 @@ struct webgpu_context_struct { wgpu::Queue queue; wgpu::Limits limits; - std::recursive_mutex mutex; + // Separate this out from limits since on some Metal systems, the limit returned by + // querying the limits is higher than the actual allowed maximum. + uint32_t max_wg_size_x; - bool device_init = false; + std::recursive_mutex mutex; webgpu_buf_pool param_buf_pool; webgpu_buf_pool set_rows_error_buf_pool; wgpu::ComputePipeline memset_pipeline; - wgpu::ComputePipeline mul_mat_pipeline; + wgpu::ComputePipeline mul_mat_pipeline[30][2]; wgpu::ComputePipeline set_rows_pipeline; + wgpu::ComputePipeline get_rows_pipeline[30]; + wgpu::ComputePipeline get_rows_f32_no_vec_pipeline; wgpu::ComputePipeline cpy_pipeline; + wgpu::ComputePipeline add_pipeline[2]; + wgpu::ComputePipeline add_ip_pipeline[2]; + wgpu::ComputePipeline mul_pipeline[2]; + wgpu::ComputePipeline mul_ip_pipeline[2]; + wgpu::ComputePipeline rms_norm_pipeline; + wgpu::ComputePipeline rms_norm_ip_pipeline; size_t memset_bytes_per_thread; @@ -234,14 +244,15 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) { std::lock_guard lock(ctx->mutex); if (ctx->callback_futures.empty()) { // no existing callbacks, wait on queue submission - ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data); - } - }), - UINT64_MAX); + ctx->instance.WaitAny( + ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, + [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", + std::string(message).c_str()); + } + }), + UINT64_MAX); } else { // existing callbacks, wait on them ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX); @@ -278,7 +289,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) { wgpu::CallbackMode::AllowSpontaneous, [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data); + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); } // Free the staged buffers ctx->param_buf_pool.free_bufs(staged_param_bufs); @@ -288,13 +299,10 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) { // Check for errrors in SET_ROWS operations for (auto & error_bufs : staged_set_row_error_bufs) { wgpu::Future f = error_bufs.host_buf.MapAsync( - wgpu::MapMode::Read, - 0, - error_bufs.host_buf.GetSize(), - wgpu::CallbackMode::AllowSpontaneous, + wgpu::MapMode::Read, 0, error_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", message.data); + GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); } else { const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange(); if (*error_data) { @@ -313,10 +321,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, wgpu::MapMode mode, size_t offset, size_t size) { - ctx->instance.WaitAny(buffer.MapAsync(mode, - offset, - size, - wgpu::CallbackMode::AllowSpontaneous, + ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, [](wgpu::MapAsyncStatus status, wgpu::StringView message) { if (status != wgpu::MapAsyncStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", @@ -331,6 +336,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and // debug statements in the shader, and then call this function after encoding the commands and submitting them. static void ggml_backend_webgpu_debug(webgpu_context & ctx) { + ggml_backend_webgpu_submit_queue(ctx); wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); @@ -352,7 +358,8 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & std::vector params, std::vector bind_group_entries, uint32_t wg_x, - bool submit_and_wait = false) { + const char * bind_group_label = nullptr, + bool submit_and_wait = false) { webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); @@ -373,6 +380,9 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & bind_group_desc.layout = pipeline.GetBindGroupLayout(0); bind_group_desc.entryCount = bind_group_entries.size(); bind_group_desc.entries = bind_group_entries.data(); + if (bind_group_label) { + bind_group_desc.label = bind_group_label; + } wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); @@ -416,18 +426,9 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, std::vector entries = { { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() } }; - size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread; + size_t bytes_per_wg = ctx->max_wg_size_x * ctx->memset_bytes_per_thread; uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true); -} - -static size_t ggml_backend_webgpu_tensor_offset(const ggml_tensor * tensor) { - return webgpu_tensor_offset(tensor) + tensor->view_offs; -} - -static wgpu::Buffer ggml_backend_webgpu_tensor_buf(const ggml_tensor * tensor) { - ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; - return ctx->buffer; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, "MEMSET", true); } /** End WebGPU Actions */ @@ -447,50 +448,65 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { GGML_UNUSED(ctx); } +static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { + return webgpu_tensor_offset(tensor) + tensor->view_offs; +} + +static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { + ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + return ctx->buffer; +} + +static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); +} + +static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1); +} + +static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { + return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & + ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1); +} + +// Used to determine if two tensors are the same for in-place operations +static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { + return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && + (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); +} + static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - size_t src_offset = ggml_backend_webgpu_tensor_offset(src); - // assumes power of 2 offset alignment - size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); - // align to minimum offset alignment - src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); - size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst); - size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); - dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); - uint32_t ne = (uint32_t) ggml_nelements(dst); - std::vector params = { ne, - (uint32_t) (src_misalignment / ggml_type_size(src->type)), - (uint32_t) (dst_misalignment / ggml_type_size(dst->type)), - // Convert byte-strides to element-strides - (uint32_t) (src->nb[0] / ggml_type_size(src->type)), - (uint32_t) (src->nb[1] / ggml_type_size(src->type)), - (uint32_t) (src->nb[2] / ggml_type_size(src->type)), - (uint32_t) (src->nb[3] / ggml_type_size(src->type)), - (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Logical shape — same for both tensors even if permuted - (uint32_t) src->ne[0], - (uint32_t) src->ne[1], - (uint32_t) src->ne[2], - (uint32_t) src->ne[3] }; + uint32_t ne = (uint32_t) ggml_nelements(dst); + + std::vector params = { + ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + // Convert byte-strides to element-strides + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + // Logical shape — same for both tensors even if permuted + (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3] + }; std::vector entries = { { .binding = 0, - .buffer = ggml_backend_webgpu_tensor_buf(src), - .offset = src_offset, - .size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & - ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, { .binding = 1, - .buffer = ggml_backend_webgpu_tensor_buf(dst), - .offset = dst_offset, - .size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & - ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) } + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; + size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x); + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op)); } static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { @@ -504,76 +520,103 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t error_bufs.host_buf.Unmap(); } - size_t src_offset = ggml_backend_webgpu_tensor_offset(src); - // assumes power of 2 offset alignment - size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); - // align to minimum offset alignment - src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); - size_t idx_offset = ggml_backend_webgpu_tensor_offset(idx); - size_t idx_misalignment = idx_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); - idx_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); - size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst); - size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); - dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); - - std::vector params = { (uint32_t) (src_misalignment / ggml_type_size(src->type)), - (uint32_t) (idx_misalignment / ggml_type_size(idx->type)), - (uint32_t) (dst_misalignment / ggml_type_size(dst->type)), - // Convert byte-strides to element-strides - (uint32_t) (src->nb[1] / ggml_type_size(src->type)), - (uint32_t) (src->nb[2] / ggml_type_size(src->type)), - (uint32_t) (src->nb[3] / ggml_type_size(src->type)), - (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), - (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), - (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), - (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Shape of src - (uint32_t) src->ne[0], - (uint32_t) src->ne[1], - (uint32_t) src->ne[2], - (uint32_t) src->ne[3], - // Shape of idx - (uint32_t) (idx->ne[1]), - (uint32_t) (idx->ne[2]) }; + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + // Convert byte-strides to element-strides + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), + (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + // Shape of src + (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3], + // Shape of idx + (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2]) + }; std::vector entries = { { .binding = 0, - .buffer = ggml_backend_webgpu_tensor_buf(src), - .offset = ggml_backend_webgpu_tensor_offset(src), - .size = ggml_nbytes(src) }, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, { .binding = 1, - .buffer = ggml_backend_webgpu_tensor_buf(idx), - .offset = ggml_backend_webgpu_tensor_offset(idx), - .size = ggml_nbytes(idx) }, + .buffer = ggml_webgpu_tensor_buf(idx), + .offset = ggml_webgpu_tensor_align_offset(ctx, idx), + .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, { .binding = 2, - .buffer = ggml_backend_webgpu_tensor_buf(dst), - .offset = ggml_backend_webgpu_tensor_offset(dst), - .size = ggml_nbytes(dst) }, - { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() } + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() } }; - size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; + size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; std::lock_guard lock(ctx->mutex); ctx->staged_set_row_error_bufs.push_back(error_bufs); - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x); + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x, ggml_op_name(dst->op)); +} + +static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + // Convert byte-strides to element-strides + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), + (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + // Shape of dst + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], + // Shape of idx + (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2]) + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(idx), + .offset = ggml_webgpu_tensor_align_offset(ctx, idx), + .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (dst->ne[1] * dst->ne[2] * dst->ne[3] + max_wg_size - 1) / max_wg_size; + + wgpu::ComputePipeline pipeline = ctx->get_rows_pipeline[src->type]; + if (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 != 0) { + pipeline = ctx->get_rows_f32_no_vec_pipeline; + } + ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); } static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) dst->ne[1], // number of rows in result (M) (uint32_t) dst->ne[0], // number of columns in result (N) (uint32_t) src0->ne[0], // number of columns in src0/src1 (K) - (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 1 - (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 1 - (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 2 - (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 2 - (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 3 - (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 3 + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1 + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1 + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2 + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2 + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3 + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3 (uint32_t) src0->ne[2], // batch size in dimension 2 (uint32_t) src0->ne[3], // batch size in dimension 3 (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2 @@ -582,22 +625,119 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t std::vector entries = { { .binding = 0, - .buffer = ggml_backend_webgpu_tensor_buf(src0), - .offset = ggml_backend_webgpu_tensor_offset(src0), - .size = ggml_nbytes(src0) }, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, { .binding = 1, - .buffer = ggml_backend_webgpu_tensor_buf(src1), - .offset = ggml_backend_webgpu_tensor_offset(src1), - .size = ggml_nbytes(src1) }, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, { .binding = 2, - .buffer = ggml_backend_webgpu_tensor_buf(dst), - .offset = ggml_backend_webgpu_tensor_offset(dst), - .size = ggml_nbytes(dst) } + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, }; uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline, params, entries, wg_x); + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x, + ggml_op_name(dst->op)); +} + +static void ggml_webgpu_binary_op(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst, + wgpu::ComputePipeline & pipeline, + bool in_place) { + std::vector params = { + (uint32_t) ggml_nelements(dst), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) } + }; + if (!in_place) { + entries.push_back({ .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } + + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); +} + +static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + bool in_place = ggml_webgpu_tensor_equal(src, dst); + + uint32_t eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + }; + if (!in_place) { + params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type))); + } + params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type))); + params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type))); + params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type))); + if (!in_place) { + params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type))); + params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type))); + params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type))); + } + params.push_back((uint32_t) src->ne[0]); + params.push_back((uint32_t) src->ne[1]); + params.push_back((uint32_t) src->ne[2]); + params.push_back((uint32_t) src->ne[3]); + params.push_back(eps); // epsilon, will be bitcast to float in shader + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) } + }; + if (!in_place) { + entries.push_back({ .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } + + wgpu::ComputePipeline pipeline; + if (in_place) { + pipeline = ctx->rms_norm_ip_pipeline; + } else { + pipeline = ctx->rms_norm_pipeline; + } + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); } // Returns true if node has enqueued work into the queue, false otherwise @@ -615,22 +755,38 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { case GGML_OP_NONE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_RESHAPE: return false; case GGML_OP_CPY: - { - ggml_webgpu_cpy(ctx, src0, node); - break; - } + ggml_webgpu_cpy(ctx, src0, node); + break; case GGML_OP_SET_ROWS: - { - ggml_webgpu_set_rows(ctx, src0, src1, node); - break; - } + ggml_webgpu_set_rows(ctx, src0, src1, node); + break; + case GGML_OP_GET_ROWS: + ggml_webgpu_get_rows(ctx, src0, src1, node); + break; case GGML_OP_MUL_MAT: - { - ggml_webgpu_mul_mat(ctx, src0, src1, node); - break; + ggml_webgpu_mul_mat(ctx, src0, src1, node); + break; + case GGML_OP_ADD: + if (ggml_webgpu_tensor_equal(src0, node)) { + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true); + } else { + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false); } + break; + case GGML_OP_MUL: + if (ggml_webgpu_tensor_equal(src0, node)) { + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true); + } else { + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false); + } + break; + case GGML_OP_RMS_NORM: + ggml_webgpu_rms_norm(ctx, src0, node); + break; default: return false; } @@ -667,6 +823,7 @@ static ggml_backend_i ggml_backend_webgpu_i = { /* .graph_compute = */ ggml_backend_webgpu_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, }; /* End GGML Backend Interface */ @@ -732,8 +889,8 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i]; } // memset the remaining bytes - ggml_backend_webgpu_buffer_memset( - webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size); + ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), + remaining_size); } else { // wait for WriteBuffer to complete ggml_backend_webgpu_wait_on_submission(webgpu_ctx); @@ -767,11 +924,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, if (webgpu_ctx->get_tensor_staging_buf) { webgpu_ctx->get_tensor_staging_buf.Destroy(); } - ggml_webgpu_create_buffer(device, - webgpu_ctx->get_tensor_staging_buf, - final_size, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, - "get_tensor_staging_buf"); + ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf"); } // Copy the data from the buffer to the staging buffer @@ -825,9 +979,8 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); wgpu::Buffer buf; - ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, - buf, - size, + ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, + (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1), wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, "allocated_buffer"); @@ -891,9 +1044,17 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) { return reinterpret_cast((void *) guid_str); } +// The max workgroup size is a common constant +static std::vector ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) { + std::vector constants(1); + constants[0].key = "wg_size"; + constants[0].value = webgpu_ctx->max_wg_size_x; + return constants; +} + static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { // we use the maximum workgroup size for the memset pipeline - size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX; + size_t max_wg_size = webgpu_ctx->max_wg_size_x; size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension; // Size the bytes_per_thread so that the largest buffer size can be handled webgpu_ctx->memset_bytes_per_thread = @@ -907,22 +1068,142 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32], + wgsl_mul_mat_f32_f32, "mul_mat_f32_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16], + wgsl_mul_mat_f16_f16, "mul_mat_f16_f16"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32], + wgsl_mul_mat_f16_f32, "mul_mat_f16_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32], + wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32], + wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32], + wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32], + wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32], + wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32], + wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32], + wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32], + wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32], + wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32], + wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32], + wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32], + wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32], + wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32], + wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32], + wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32], + wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32], + wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32], + wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], + wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants(1); - constants[0].key = "wg_size"; - constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX; - ggml_webgpu_create_pipeline( - webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", + ggml_webgpu_max_wg_size_entry(webgpu_ctx)); +} + +static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec, + "get_rows_f32_vec", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32, + "get_rows_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F16], wgsl_get_rows_f16, + "get_rows_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_I32], wgsl_get_rows_i32, + "get_rows_i32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_0], wgsl_get_rows_q4_0, + "get_rows_q4_0", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_1], wgsl_get_rows_q4_1, + "get_rows_q4_1", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_0], wgsl_get_rows_q5_0, + "get_rows_q5_0", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_1], wgsl_get_rows_q5_1, + "get_rows_q5_1", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q8_0], wgsl_get_rows_q8_0, + "get_rows_q8_0", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q2_K], wgsl_get_rows_q2_k, + "get_rows_q2_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q3_K], wgsl_get_rows_q3_k, + "get_rows_q3_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_K], wgsl_get_rows_q4_k, + "get_rows_q4_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_K], wgsl_get_rows_q5_k, + "get_rows_q5_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q6_K], wgsl_get_rows_q6_k, + "get_rows_q6_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XXS], + wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XS], + wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_S], wgsl_get_rows_iq2_s, + "get_rows_iq2_s", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_XXS], + wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_S], wgsl_get_rows_iq3_s, + "get_rows_iq3_s", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_S], wgsl_get_rows_iq1_s, + "get_rows_iq1_s", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_M], wgsl_get_rows_iq1_m, + "get_rows_iq1_m", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_NL], + wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_XS], + wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); } static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants(1); - constants[0].key = "wg_size"; - constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX; - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", + ggml_webgpu_max_wg_size_entry(webgpu_ctx)); +} + +static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32, + "add_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16, + "add_in_place_f16", constants); +} + +static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32, + "mul_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16, + "mul_in_place_f16", constants); +} + +static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place, + "rms_norm_in_place", constants); } static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { @@ -933,79 +1214,6 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co ggml_backend_webgpu_device_context * dev_ctx = static_cast(dev->context); webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx; - // Multiple threads may try to initialize the device - std::lock_guard lock(webgpu_ctx->mutex); - if (!webgpu_ctx->device_init) { - // Initialize device - std::vector required_features = { wgpu::FeatureName::ShaderF16, - wgpu::FeatureName::ImplicitDeviceSynchronization }; - wgpu::DeviceDescriptor dev_desc; - dev_desc.requiredLimits = &webgpu_ctx->limits; - dev_desc.requiredFeatures = required_features.data(); - dev_desc.requiredFeatureCount = required_features.size(); - dev_desc.SetDeviceLostCallback( - wgpu::CallbackMode::AllowSpontaneous, - [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { - GGML_UNUSED(device); - GGML_LOG_ERROR( - "ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), message.data); - }); - dev_desc.SetUncapturedErrorCallback( - [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { - GGML_UNUSED(device); - GGML_LOG_ERROR( - "ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), message.data); - }); - webgpu_ctx->instance.WaitAny( - webgpu_ctx->adapter.RequestDevice( - &dev_desc, - wgpu::CallbackMode::AllowSpontaneous, - [webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { - if (status != wgpu::RequestDeviceStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data); - return; - } - webgpu_ctx->device = std::move(device); - }), - UINT64_MAX); - GGML_ASSERT(webgpu_ctx->device != nullptr); - - // Initialize (compute) queue - webgpu_ctx->queue = webgpu_ctx->device.GetQueue(); - - // Create buffer pool for shader parameters - webgpu_ctx->param_buf_pool.init(webgpu_ctx->device, - WEBGPU_NUM_PARAM_BUFS, - WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); - webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->device, - WEBGPU_NUM_SET_ROWS_ERROR_BUFS, - WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); - - ggml_webgpu_init_memset_pipeline(webgpu_ctx); - ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); - ggml_webgpu_init_set_rows_pipeline(webgpu_ctx); - ggml_webgpu_init_cpy_pipeline(webgpu_ctx); - -#ifdef GGML_WEBGPU_DEBUG - // Initialize debug buffers - ggml_webgpu_create_buffer(webgpu_ctx->device, - webgpu_ctx->debug_host_buf, - WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, - "debug_host_buf"); - ggml_webgpu_create_buffer(webgpu_ctx->device, - webgpu_ctx->debug_dev_buf, - WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), - wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, - "debug_dev_buf"); -#endif - webgpu_ctx->device_init = true; - } - static ggml_backend_webgpu_context backend_ctx; backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; backend_ctx.webgpu_ctx = webgpu_ctx; @@ -1045,21 +1253,124 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name; } +static bool ggml_webgpu_supported_qtype(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + return true; + default: + return false; + } +} + static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - GGML_UNUSED(dev); + ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); + + webgpu_context webgpu_ctx = ctx->webgpu_ctx; + + ggml_tensor * src0 = op->src[0]; + ggml_tensor * src1 = op->src[1]; + // on smaller devices (or CI), tensors may be larger than the max storage buffer size + if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || + (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || + (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) { + return false; + } + bool supports_op = false; switch (op->op) { case GGML_OP_NONE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: - return true; - case GGML_OP_CPY | GGML_OP_SET_ROWS: - return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_TRANSPOSE: + case GGML_OP_RESHAPE: + supports_op = true; + break; + case GGML_OP_ADD: + case GGML_OP_MUL: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) && + (op->src[1]->type == op->type); + break; + case GGML_OP_CPY: + case GGML_OP_SET_ROWS: + supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32); + break; + case GGML_OP_GET_ROWS: + if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || + op->src[0]->type == GGML_TYPE_I32 || ggml_webgpu_supported_qtype(op->src[0]->type)) { + supports_op = (op->type == GGML_TYPE_F32); + } + break; case GGML_OP_MUL_MAT: - return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + { + switch (op->src[1]->type) { + case GGML_TYPE_F16: + supports_op = (op->src[0]->type == GGML_TYPE_F16); + break; + case GGML_TYPE_F32: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + supports_op = true; + break; + default: + break; + } + default: + break; + } + break; + } + case GGML_OP_RMS_NORM: + supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; + break; default: - return false; + break; } +#ifdef GGML_WEBGPU_DEBUG + if (!supports_op) { + WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type) + << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null") + << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null")); + } +#endif + return supports_op; } static struct ggml_backend_device_i ggml_backend_webgpu_device_i = { @@ -1105,38 +1416,95 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t webgpu_context ctx = reg_ctx->webgpu_ctx; wgpu::RequestAdapterOptions options = {}; - auto callback = - [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message, void * userdata) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - *static_cast(userdata) = std::move(adapter); - }; - void * userdata = &ctx->adapter; - ctx->instance.WaitAny( - ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX); + ctx->instance.WaitAny(ctx->instance.RequestAdapter( + &options, wgpu::CallbackMode::AllowSpontaneous, + [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + ctx->adapter = std::move(adapter); + }), + UINT64_MAX); GGML_ASSERT(ctx->adapter != nullptr); ctx->adapter.GetLimits(&ctx->limits); + ctx->max_wg_size_x = 288; // default value wgpu::AdapterInfo info{}; ctx->adapter.GetInfo(&info); + // Initialize device + std::vector required_features = { wgpu::FeatureName::ShaderF16, + wgpu::FeatureName::ImplicitDeviceSynchronization }; + wgpu::DeviceDescriptor dev_desc; + dev_desc.requiredLimits = &ctx->limits; + dev_desc.requiredFeatures = required_features.data(); + dev_desc.requiredFeatureCount = required_features.size(); + dev_desc.SetDeviceLostCallback( + wgpu::CallbackMode::AllowSpontaneous, + [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + GGML_UNUSED(device); + GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), + std::string(message).c_str()); + }); + dev_desc.SetUncapturedErrorCallback( + [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { + GGML_UNUSED(device); + GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), + std::string(message).c_str()); + }); + ctx->instance.WaitAny(ctx->adapter.RequestDevice( + &dev_desc, wgpu::CallbackMode::AllowSpontaneous, + [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { + if (status != wgpu::RequestDeviceStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", + std::string(message).c_str()); + return; + } + ctx->device = std::move(device); + }), + UINT64_MAX); + GGML_ASSERT(ctx->device != nullptr); + + // Initialize (compute) queue + ctx->queue = ctx->device.GetQueue(); + + // Create buffer pool for shader parameters + ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); + + ggml_webgpu_init_memset_pipeline(ctx); + ggml_webgpu_init_mul_mat_pipeline(ctx); + ggml_webgpu_init_set_rows_pipeline(ctx); + ggml_webgpu_init_get_rows_pipeline(ctx); + ggml_webgpu_init_cpy_pipeline(ctx); + ggml_webgpu_init_add_pipeline(ctx); + ggml_webgpu_init_mul_pipeline(ctx); + ggml_webgpu_init_rms_norm_pipeline(ctx); + +#ifdef GGML_WEBGPU_DEBUG + // Initialize debug buffers + ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf"); + ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf"); +#endif + static ggml_backend_webgpu_device_context device_ctx; device_ctx.webgpu_ctx = ctx; device_ctx.device_name = GGML_WEBGPU_NAME; - device_ctx.device_desc = std::string(info.description.data); + device_ctx.device_desc = info.description; GGML_LOG_INFO( "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " "device_desc: %s\n", - info.vendorID, - info.vendor.data, - info.architecture.data, - info.deviceID, - info.device.data, - info.description.data); + info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID, + std::string(info.device).c_str(), std::string(info.description).c_str()); // See GGML Backend Device Interface section static ggml_backend_device device = { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl new file mode 100644 index 00000000000..f261cbb553a --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl @@ -0,0 +1,44 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "TYPE" : "f32", + } + }, + { + "REPLS": { + "TYPE" : "f16", + } + } +] + +#end(VARIANTS) + +#define(SHADER) + +enable f16; + +#include "binary_head.tmpl" + +@group(0) @binding(0) +var src0: array<{{TYPE}}>; + +@group(0) @binding(1) +var src1: array<{{TYPE}}>; + +@group(0) @binding(2) +var dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)]; + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl new file mode 100644 index 00000000000..903f7bdbcc5 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl @@ -0,0 +1,41 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "TYPE" : "f32", + } + }, + { + "REPLS": { + "TYPE" : "f16", + } + } +] + +#end(VARIANTS) + +#define(SHADER) + +enable f16; + +#include "binary_head.tmpl" + +@group(0) @binding(0) +var src0: array<{{TYPE}}>; + +@group(0) @binding(1) +var src1: array<{{TYPE}}>; + +@group(0) @binding(2) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)]; + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl new file mode 100644 index 00000000000..4b254f468d6 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl @@ -0,0 +1,45 @@ +struct Params { + ne: u32, + + // offsets in elements + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src1_0: u32, + stride_src1_1: u32, + stride_src1_2: u32, + stride_src1_3: u32, + + a_ne0: u32, + a_ne1: u32, + a_ne2: u32, + + b_ne0: u32, + b_ne1: u32, + b_ne2: u32, + b_ne3: u32, +}; + +fn src1_index(_i: u32) -> u32 { + var i = _i; + let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); + i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); + let a_i2 = i / (params.a_ne1 * params.a_ne0); + i = i % (params.a_ne1 * params.a_ne0); + let a_i1 = i / params.a_ne0; + let a_i0 = i % params.a_ne0; + + // handle repetition of b + // index loops back to the beginning and repeats after elements are exhausted = modulo + let b_i0 = a_i0 % params.b_ne0; + let b_i1 = a_i1 % params.b_ne1; + let b_i2 = a_i2 % params.b_ne2; + let b_i3 = a_i3 % params.b_ne3; + + // compute index for position in b's flat array + return b_i0 * params.stride_src1_0 + + b_i1 * params.stride_src1_1 + + b_i2 * params.stride_src1_2 + + b_i3 * params.stride_src1_3; +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl new file mode 100644 index 00000000000..389c97bb51b --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -0,0 +1,930 @@ +#decl(BYTE_HELPERS) + +fn get_byte(value: u32, index: u32) -> u32 { + return (value >> (index * 8)) & 0xFF; +} + +fn get_byte_i32(value: u32, index: u32) -> i32 { + return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; +} + +#enddecl(BYTE_HELPERS) + +#decl(Q4_0_T) +struct q4_0 { + d: f16, + qs: array +}; +#enddecl(Q4_0_T) + +#decl(Q4_1_T) +struct q4_1 { + d: f16, + m: f16, + qs: array +}; +#enddecl(Q4_1_T) + +#decl(Q5_0_T) +struct q5_0 { + d: f16, + qh: array, + qs: array +}; +#enddecl(Q5_0_T) + +#decl(Q5_1_T) +struct q5_1 { + d: f16, + m: f16, + qh: u32, + qs: array +}; +#enddecl(Q5_1_T) + +#decl(Q8_0_T) +struct q8_0 { + d: f16, + qs: array +}; +#enddecl(Q8_0_T) + +#decl(Q8_1_T) +struct q8_1 { + d: f16, + m: f16, + qs: array +}; +#enddecl(Q8_1_T) + +#decl(Q2_K_T) +struct q2_k { + scales: array, + qs: array, + d: f16, + dmin: f16 +}; +#enddecl(Q2_K_T) + +#decl(Q3_K_T) +struct q3_k { + hmask: array, + qs: array, + scales: array, + d: f16 +}; +#enddecl(Q3_K_T) + +#decl(Q45_K_SCALE_MIN) + +fn get_scale_min(is: u32, scales: array) -> vec2 { + if (is < 4) { + let sc_byte = get_byte(scales[is / 4], is % 4); + let min_byte = get_byte(scales[(is + 4) / 4], is % 4); + return vec2(f32(sc_byte & 63), f32(min_byte & 63)); + } else { + let sc_min_lo = get_byte(scales[(is + 4) / 4], (is + 4) % 4); + let sc_hi = get_byte(scales[(is - 4) / 4], (is - 4) % 4); + let min_hi = get_byte(scales[is / 4], is % 4); + let sc = (sc_min_lo & 0xF) | ((sc_hi >> 6) << 4); + let m = (sc_min_lo >> 4) | ((min_hi >> 6) << 4); + return vec2(f32(sc), f32(m)); + } +} + +#enddecl(Q45_K_SCALE_MIN) + +#decl(Q4_K_T) +struct q4_k { + d: f16, + dmin: f16, + scales: array, + qs: array +}; +#enddecl(Q4_K_T) + +#decl(Q5_K_T) +struct q5_k { + d: f16, + dmin: f16, + scales: array, + qh: array, + qs: array +}; +#enddecl(Q5_K_T) + +#decl(Q6_K_T) +struct q6_k { + ql: array, + qh: array, + scales: array, + d: f16 +}; +#enddecl(Q6_K_T) + +#decl(IQ2_XXS_T) +struct iq2_xxs { + d: f16, + qs: array +}; +#enddecl(IQ2_XXS_T) + +#decl(IQ2_XS_T) +struct iq2_xs { + d: f16, + qs: array, + scales: array +}; +#enddecl(IQ2_XS_T) + +#decl(IQ2_S_T) +struct iq2_s { + d: f16, + qs: array, + qh: array, + scales: array +}; +#enddecl(IQ2_S_T) + +#decl(IQ3_XSS_T) +struct iq3_xxs { + d: f16, + qs: array +}; +#enddecl(IQ3_XSS_T) + +#decl(IQ3_S_T) +struct iq3_s { + d: f16, + qs: array, + qh: array, + signs: array, + scales: array +}; +#enddecl(IQ3_S_T) + +#decl(IQ1_S_T) +struct iq1_s { + d: f16, + qs: array, + qh: array +}; +#enddecl(IQ1_S_T) + +#decl(IQ1_M_T) +struct iq1_m { + qs: array, + qh: array, + scales: array +}; +#enddecl(IQ1_M_T) + +#decl(IQ4_NL_T) +struct iq4_nl { + d: f16, + qs: array, +}; +#enddecl(IQ4_NL_T) + +#decl(IQ4_XS_T) +struct iq4_xs { + d: f16, + scales_h: f16, + scales_l: u32, + qs: array +}; +#enddecl(IQ4_XS_T) + +#decl(IQ23_TABLES) +const kmask_iq2xs : array = array( + 0x08040201u, // 1, 2, 4, 8 + 0x80402010u // 16, 32, 64, 128 +); + +const ksigns_iq2xs: array = array( + 0x03828100,0x87060584,0x8b0a0988,0x0f8e8d0c, + 0x93121190,0x17969514,0x1b9a9918,0x9f1e1d9c, + 0xa32221a0,0x27a6a524,0x2baaa928,0xaf2e2dac, + 0x33b2b130,0xb73635b4,0xbb3a39b8,0x3fbebd3c, + 0xc34241c0,0x47c6c544,0x4bcac948,0xcf4e4dcc, + 0x53d2d150,0xd75655d4,0xdb5a59d8,0x5fdedd5c, + 0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c, + 0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc +); +#enddecl(IQ23_TABLES) + +#decl(IQ2_XXS_GRID) +const iq2xxs_grid = array( + 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, + 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808, + 0x082b082b, 0x08080808, 0x082b2b08, 0x08080808, 0x082b2b2b, 0x08080808, 0x19080819, 0x08080808, + 0x19081908, 0x08080808, 0x19190808, 0x08080808, 0x19192b08, 0x08080808, 0x192b0819, 0x08080808, + 0x192b1908, 0x08080808, 0x2b080808, 0x08080808, 0x2b08082b, 0x08080808, 0x2b082b2b, 0x08080808, + 0x2b2b082b, 0x08080808, 0x08080819, 0x08080819, 0x08081908, 0x08080819, 0x08190808, 0x08080819, + 0x08191919, 0x08080819, 0x19080808, 0x08080819, 0x2b081908, 0x08080819, 0x2b192b08, 0x08080819, + 0x08080808, 0x0808082b, 0x0808082b, 0x0808082b, 0x082b082b, 0x0808082b, 0x2b08082b, 0x0808082b, + 0x08080819, 0x08081908, 0x08081908, 0x08081908, 0x08190808, 0x08081908, 0x082b0819, 0x08081908, + 0x082b1908, 0x08081908, 0x19080808, 0x08081908, 0x1908082b, 0x08081908, 0x19082b08, 0x08081908, + 0x192b0808, 0x08081908, 0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b190808, 0x08081908, + 0x2b2b1908, 0x08081908, 0x08080808, 0x08081919, 0x0808082b, 0x08081919, 0x08082b08, 0x08081919, + 0x082b0808, 0x08081919, 0x1908192b, 0x08081919, 0x192b2b19, 0x08081919, 0x2b080808, 0x08081919, + 0x2b190819, 0x08081919, 0x08082b19, 0x0808192b, 0x08190808, 0x0808192b, 0x19080808, 0x0808192b, + 0x2b081908, 0x0808192b, 0x2b2b1908, 0x0808192b, 0x08080808, 0x08082b08, 0x08081919, 0x08082b08, + 0x08082b08, 0x08082b08, 0x08191908, 0x08082b08, 0x082b2b08, 0x08082b08, 0x19080819, 0x08082b08, + 0x19081908, 0x08082b08, 0x19190808, 0x08082b08, 0x1919082b, 0x08082b08, 0x2b082b08, 0x08082b08, + 0x08081908, 0x08082b19, 0x19080808, 0x08082b19, 0x0808082b, 0x08082b2b, 0x08191908, 0x08082b2b, + 0x08080819, 0x08190808, 0x08081908, 0x08190808, 0x08190808, 0x08190808, 0x082b0819, 0x08190808, + 0x19080808, 0x08190808, 0x192b0808, 0x08190808, 0x2b081908, 0x08190808, 0x2b190808, 0x08190808, + 0x2b191919, 0x08190808, 0x08080808, 0x08190819, 0x08082b08, 0x08190819, 0x082b0808, 0x08190819, + 0x19190808, 0x08190819, 0x19192b2b, 0x08190819, 0x2b080808, 0x08190819, 0x082b1908, 0x0819082b, + 0x19081919, 0x0819082b, 0x08080808, 0x08191908, 0x08082b08, 0x08191908, 0x082b0808, 0x08191908, + 0x082b1919, 0x08191908, 0x19082b19, 0x08191908, 0x2b080808, 0x08191908, 0x08192b08, 0x08191919, + 0x192b082b, 0x08191919, 0x08080808, 0x0819192b, 0x0819192b, 0x0819192b, 0x08080819, 0x08192b08, + 0x08081908, 0x08192b08, 0x08190808, 0x08192b08, 0x19080808, 0x08192b08, 0x2b080819, 0x08192b08, + 0x08080808, 0x08192b19, 0x08081919, 0x08192b19, 0x2b2b0808, 0x08192b19, 0x19190819, 0x08192b2b, + 0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08082b2b, 0x082b0808, 0x19081908, 0x082b0808, + 0x192b0819, 0x082b0808, 0x2b080808, 0x082b0808, 0x2b08082b, 0x082b0808, 0x082b2b19, 0x082b0819, + 0x19082b08, 0x082b0819, 0x08080808, 0x082b082b, 0x0808082b, 0x082b082b, 0x08080819, 0x082b1908, + 0x08081908, 0x082b1908, 0x08190808, 0x082b1908, 0x19080808, 0x082b1908, 0x1919192b, 0x082b1908, + 0x08080808, 0x082b1919, 0x19080819, 0x082b1919, 0x192b1908, 0x082b1919, 0x2b190808, 0x082b192b, + 0x08082b08, 0x082b2b08, 0x082b0808, 0x082b2b08, 0x2b191908, 0x082b2b08, 0x19081908, 0x082b2b2b, + 0x08080819, 0x19080808, 0x08081908, 0x19080808, 0x08190808, 0x19080808, 0x08192b08, 0x19080808, + 0x082b0819, 0x19080808, 0x082b1908, 0x19080808, 0x19080808, 0x19080808, 0x19082b08, 0x19080808, + 0x1919192b, 0x19080808, 0x192b0808, 0x19080808, 0x2b080819, 0x19080808, 0x2b081908, 0x19080808, + 0x2b190808, 0x19080808, 0x08080808, 0x19080819, 0x082b0808, 0x19080819, 0x192b0819, 0x19080819, + 0x2b080808, 0x19080819, 0x2b081919, 0x19080819, 0x08080819, 0x1908082b, 0x08190808, 0x1908082b, + 0x19082b08, 0x1908082b, 0x1919192b, 0x1908082b, 0x192b2b08, 0x1908082b, 0x08080808, 0x19081908, + 0x08082b08, 0x19081908, 0x082b0808, 0x19081908, 0x2b080808, 0x19081908, 0x2b192b19, 0x19081908, + 0x0819082b, 0x19081919, 0x082b1908, 0x19081919, 0x08080808, 0x1908192b, 0x08080819, 0x19082b08, + 0x08081908, 0x19082b08, 0x08190808, 0x19082b08, 0x19080808, 0x19082b08, 0x19081919, 0x19082b08, + 0x08080808, 0x19082b19, 0x19192b08, 0x19082b19, 0x192b0819, 0x19082b19, 0x2b08082b, 0x19082b19, + 0x19081919, 0x19082b2b, 0x2b190808, 0x19082b2b, 0x08080808, 0x19190808, 0x08082b08, 0x19190808, + 0x08190819, 0x19190808, 0x08192b19, 0x19190808, 0x082b0808, 0x19190808, 0x2b080808, 0x19190808, + 0x2b082b08, 0x19190808, 0x08081908, 0x19190819, 0x1908082b, 0x19190819, 0x2b2b1908, 0x19190819, + 0x2b190819, 0x1919082b, 0x2b190808, 0x19191908, 0x2b19082b, 0x19191908, 0x08082b2b, 0x19191919, + 0x08080819, 0x1919192b, 0x19191908, 0x1919192b, 0x08080808, 0x19192b08, 0x08190819, 0x19192b08, + 0x08192b19, 0x19192b08, 0x192b1908, 0x19192b08, 0x19080808, 0x19192b19, 0x08082b08, 0x19192b2b, + 0x08081908, 0x192b0808, 0x08190808, 0x192b0808, 0x19080808, 0x192b0808, 0x192b2b08, 0x192b0808, + 0x08080808, 0x192b0819, 0x19191919, 0x192b0819, 0x08192b08, 0x192b082b, 0x192b0808, 0x192b082b, + 0x08080808, 0x192b1908, 0x08081919, 0x192b1908, 0x08190808, 0x192b1919, 0x0819082b, 0x192b1919, + 0x2b081908, 0x192b1919, 0x1908082b, 0x192b2b08, 0x08080808, 0x2b080808, 0x0808082b, 0x2b080808, + 0x08082b2b, 0x2b080808, 0x19080819, 0x2b080808, 0x2b08082b, 0x2b080808, 0x08081908, 0x2b080819, + 0x08192b08, 0x2b080819, 0x19080808, 0x2b080819, 0x08190819, 0x2b08082b, 0x08080819, 0x2b081908, + 0x08081908, 0x2b081908, 0x08190808, 0x2b081908, 0x08191919, 0x2b081908, 0x19080808, 0x2b081908, + 0x192b0808, 0x2b081908, 0x08080808, 0x2b081919, 0x1908192b, 0x2b081919, 0x2b191908, 0x2b081919, + 0x08082b19, 0x2b08192b, 0x19080808, 0x2b08192b, 0x192b0808, 0x2b08192b, 0x0808082b, 0x2b082b08, + 0x08081908, 0x2b082b19, 0x08190819, 0x2b082b2b, 0x08081908, 0x2b190808, 0x08190808, 0x2b190808, + 0x082b1908, 0x2b190808, 0x19080808, 0x2b190808, 0x2b2b0819, 0x2b190808, 0x0819192b, 0x2b190819, + 0x2b080808, 0x2b190819, 0x19081919, 0x2b19082b, 0x08080808, 0x2b191908, 0x082b082b, 0x2b191908, + 0x19081908, 0x2b191908, 0x19190819, 0x2b191919, 0x2b080819, 0x2b192b08, 0x082b0808, 0x2b192b19, + 0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819, + 0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19 +); +#enddecl(IQ2_XXS_GRID) + +#decl(IQ2_XS_GRID) +const iq2xs_grid = array( + 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, + 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, + 0x08192b19, 0x08080808, 0x082b0808, 0x08080808, 0x082b082b, 0x08080808, 0x082b1919, 0x08080808, + 0x082b2b08, 0x08080808, 0x19080819, 0x08080808, 0x19081908, 0x08080808, 0x1908192b, 0x08080808, + 0x19082b19, 0x08080808, 0x19190808, 0x08080808, 0x1919082b, 0x08080808, 0x19191919, 0x08080808, + 0x19192b08, 0x08080808, 0x192b0819, 0x08080808, 0x192b1908, 0x08080808, 0x2b080808, 0x08080808, + 0x2b08082b, 0x08080808, 0x2b081919, 0x08080808, 0x2b082b08, 0x08080808, 0x2b190819, 0x08080808, + 0x2b191908, 0x08080808, 0x2b192b19, 0x08080808, 0x2b2b0808, 0x08080808, 0x08080819, 0x08080819, + 0x08081908, 0x08080819, 0x0808192b, 0x08080819, 0x08082b19, 0x08080819, 0x08190808, 0x08080819, + 0x0819082b, 0x08080819, 0x08191919, 0x08080819, 0x08192b08, 0x08080819, 0x08192b2b, 0x08080819, + 0x082b0819, 0x08080819, 0x082b1908, 0x08080819, 0x19080808, 0x08080819, 0x1908082b, 0x08080819, + 0x19081919, 0x08080819, 0x19082b08, 0x08080819, 0x19190819, 0x08080819, 0x19191908, 0x08080819, + 0x192b0808, 0x08080819, 0x192b2b08, 0x08080819, 0x2b080819, 0x08080819, 0x2b081908, 0x08080819, + 0x2b190808, 0x08080819, 0x08080808, 0x0808082b, 0x0808082b, 0x0808082b, 0x08081919, 0x0808082b, + 0x08082b08, 0x0808082b, 0x08190819, 0x0808082b, 0x08191908, 0x0808082b, 0x082b0808, 0x0808082b, + 0x19080819, 0x0808082b, 0x19081908, 0x0808082b, 0x19190808, 0x0808082b, 0x19191919, 0x0808082b, + 0x2b080808, 0x0808082b, 0x2b082b2b, 0x0808082b, 0x08080819, 0x08081908, 0x08081908, 0x08081908, + 0x0808192b, 0x08081908, 0x08082b19, 0x08081908, 0x08190808, 0x08081908, 0x0819082b, 0x08081908, + 0x08191919, 0x08081908, 0x08192b08, 0x08081908, 0x082b0819, 0x08081908, 0x082b1908, 0x08081908, + 0x19080808, 0x08081908, 0x1908082b, 0x08081908, 0x19081919, 0x08081908, 0x19082b08, 0x08081908, + 0x19190819, 0x08081908, 0x19191908, 0x08081908, 0x1919192b, 0x08081908, 0x192b0808, 0x08081908, + 0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b190808, 0x08081908, 0x08080808, 0x08081919, + 0x0808082b, 0x08081919, 0x08081919, 0x08081919, 0x08082b08, 0x08081919, 0x08190819, 0x08081919, + 0x08191908, 0x08081919, 0x082b0808, 0x08081919, 0x19080819, 0x08081919, 0x19081908, 0x08081919, + 0x19190808, 0x08081919, 0x192b0819, 0x08081919, 0x2b080808, 0x08081919, 0x08080819, 0x0808192b, + 0x08081908, 0x0808192b, 0x08190808, 0x0808192b, 0x082b192b, 0x0808192b, 0x19080808, 0x0808192b, + 0x1908082b, 0x0808192b, 0x2b081908, 0x0808192b, 0x08080808, 0x08082b08, 0x0808082b, 0x08082b08, + 0x08081919, 0x08082b08, 0x08082b08, 0x08082b08, 0x08082b2b, 0x08082b08, 0x08190819, 0x08082b08, + 0x08191908, 0x08082b08, 0x082b0808, 0x08082b08, 0x082b1919, 0x08082b08, 0x19080819, 0x08082b08, + 0x19081908, 0x08082b08, 0x19190808, 0x08082b08, 0x19192b08, 0x08082b08, 0x2b080808, 0x08082b08, + 0x2b2b0808, 0x08082b08, 0x2b2b2b2b, 0x08082b08, 0x08080819, 0x08082b19, 0x08081908, 0x08082b19, + 0x08190808, 0x08082b19, 0x19080808, 0x08082b19, 0x2b080819, 0x08082b19, 0x2b082b19, 0x08082b19, + 0x08080808, 0x08082b2b, 0x082b0808, 0x08082b2b, 0x082b2b08, 0x08082b2b, 0x2b19192b, 0x08082b2b, + 0x2b2b0808, 0x08082b2b, 0x08080819, 0x08190808, 0x08081908, 0x08190808, 0x0808192b, 0x08190808, + 0x08082b19, 0x08190808, 0x08190808, 0x08190808, 0x0819082b, 0x08190808, 0x08191919, 0x08190808, + 0x08192b08, 0x08190808, 0x082b0819, 0x08190808, 0x082b1908, 0x08190808, 0x19080808, 0x08190808, + 0x1908082b, 0x08190808, 0x19081919, 0x08190808, 0x19082b08, 0x08190808, 0x19190819, 0x08190808, + 0x19191908, 0x08190808, 0x192b0808, 0x08190808, 0x192b2b2b, 0x08190808, 0x2b080819, 0x08190808, + 0x2b081908, 0x08190808, 0x2b190808, 0x08190808, 0x08080808, 0x08190819, 0x0808082b, 0x08190819, + 0x08081919, 0x08190819, 0x08082b08, 0x08190819, 0x08190819, 0x08190819, 0x08191908, 0x08190819, + 0x082b0808, 0x08190819, 0x19080819, 0x08190819, 0x19081908, 0x08190819, 0x19190808, 0x08190819, + 0x2b080808, 0x08190819, 0x2b191908, 0x08190819, 0x2b19192b, 0x08190819, 0x08080819, 0x0819082b, + 0x08081908, 0x0819082b, 0x0808192b, 0x0819082b, 0x08190808, 0x0819082b, 0x19080808, 0x0819082b, + 0x192b0808, 0x0819082b, 0x08080808, 0x08191908, 0x0808082b, 0x08191908, 0x08081919, 0x08191908, + 0x08082b08, 0x08191908, 0x08190819, 0x08191908, 0x08191908, 0x08191908, 0x082b0808, 0x08191908, + 0x19080819, 0x08191908, 0x19081908, 0x08191908, 0x19082b19, 0x08191908, 0x19190808, 0x08191908, + 0x192b1908, 0x08191908, 0x2b080808, 0x08191908, 0x08080819, 0x08191919, 0x08081908, 0x08191919, + 0x08190808, 0x08191919, 0x19080808, 0x08191919, 0x08080808, 0x0819192b, 0x08191908, 0x0819192b, + 0x19082b19, 0x0819192b, 0x08080819, 0x08192b08, 0x08081908, 0x08192b08, 0x08190808, 0x08192b08, + 0x0819082b, 0x08192b08, 0x19080808, 0x08192b08, 0x19191908, 0x08192b08, 0x2b08192b, 0x08192b08, + 0x08080808, 0x08192b19, 0x08081919, 0x08192b19, 0x192b192b, 0x08192b19, 0x19190819, 0x08192b2b, + 0x2b2b2b19, 0x08192b2b, 0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08081919, 0x082b0808, + 0x08082b08, 0x082b0808, 0x08082b2b, 0x082b0808, 0x08190819, 0x082b0808, 0x08191908, 0x082b0808, + 0x082b0808, 0x082b0808, 0x19080819, 0x082b0808, 0x19081908, 0x082b0808, 0x19190808, 0x082b0808, + 0x2b080808, 0x082b0808, 0x2b2b0808, 0x082b0808, 0x08080819, 0x082b0819, 0x08081908, 0x082b0819, + 0x08190808, 0x082b0819, 0x19080808, 0x082b0819, 0x19082b08, 0x082b0819, 0x192b1919, 0x082b0819, + 0x08080808, 0x082b082b, 0x082b082b, 0x082b082b, 0x2b080808, 0x082b082b, 0x2b2b2b08, 0x082b082b, + 0x08080819, 0x082b1908, 0x08081908, 0x082b1908, 0x08190808, 0x082b1908, 0x082b2b19, 0x082b1908, + 0x19080808, 0x082b1908, 0x08080808, 0x082b1919, 0x19080819, 0x082b1919, 0x1919082b, 0x082b1919, + 0x2b192b19, 0x082b1919, 0x08080819, 0x082b192b, 0x08192b2b, 0x082b192b, 0x2b2b192b, 0x082b192b, + 0x08080808, 0x082b2b08, 0x08082b08, 0x082b2b08, 0x08082b2b, 0x082b2b08, 0x082b0808, 0x082b2b08, + 0x19191919, 0x082b2b08, 0x2b082b08, 0x082b2b08, 0x2b2b082b, 0x082b2b08, 0x192b2b08, 0x082b2b19, + 0x2b190808, 0x082b2b19, 0x08082b08, 0x082b2b2b, 0x082b0808, 0x082b2b2b, 0x2b08082b, 0x082b2b2b, + 0x2b082b08, 0x082b2b2b, 0x2b082b2b, 0x082b2b2b, 0x08080819, 0x19080808, 0x08081908, 0x19080808, + 0x0808192b, 0x19080808, 0x08082b19, 0x19080808, 0x08190808, 0x19080808, 0x0819082b, 0x19080808, + 0x08191919, 0x19080808, 0x08192b08, 0x19080808, 0x082b0819, 0x19080808, 0x082b1908, 0x19080808, + 0x19080808, 0x19080808, 0x1908082b, 0x19080808, 0x19081919, 0x19080808, 0x19082b08, 0x19080808, + 0x19082b2b, 0x19080808, 0x19190819, 0x19080808, 0x19191908, 0x19080808, 0x192b0808, 0x19080808, + 0x192b1919, 0x19080808, 0x2b080819, 0x19080808, 0x2b081908, 0x19080808, 0x2b190808, 0x19080808, + 0x08080808, 0x19080819, 0x0808082b, 0x19080819, 0x08081919, 0x19080819, 0x08082b08, 0x19080819, + 0x08190819, 0x19080819, 0x08191908, 0x19080819, 0x082b0808, 0x19080819, 0x19080819, 0x19080819, + 0x19081908, 0x19080819, 0x19190808, 0x19080819, 0x2b080808, 0x19080819, 0x2b081919, 0x19080819, + 0x2b2b082b, 0x19080819, 0x08080819, 0x1908082b, 0x08081908, 0x1908082b, 0x08190808, 0x1908082b, + 0x0819082b, 0x1908082b, 0x082b2b19, 0x1908082b, 0x19080808, 0x1908082b, 0x08080808, 0x19081908, + 0x0808082b, 0x19081908, 0x08081919, 0x19081908, 0x08082b08, 0x19081908, 0x08190819, 0x19081908, + 0x08191908, 0x19081908, 0x08192b19, 0x19081908, 0x082b0808, 0x19081908, 0x19080819, 0x19081908, + 0x19081908, 0x19081908, 0x19190808, 0x19081908, 0x2b080808, 0x19081908, 0x2b191908, 0x19081908, + 0x08080819, 0x19081919, 0x08081908, 0x19081919, 0x08190808, 0x19081919, 0x082b1908, 0x19081919, + 0x19080808, 0x19081919, 0x2b192b2b, 0x19081919, 0x08080808, 0x1908192b, 0x08082b2b, 0x1908192b, + 0x19081908, 0x1908192b, 0x19190808, 0x1908192b, 0x08080819, 0x19082b08, 0x08081908, 0x19082b08, + 0x08190808, 0x19082b08, 0x19080808, 0x19082b08, 0x19081919, 0x19082b08, 0x19191908, 0x19082b08, + 0x192b082b, 0x19082b08, 0x08080808, 0x19082b19, 0x08190819, 0x19082b19, 0x19081908, 0x19082b19, + 0x19190808, 0x19082b19, 0x192b2b19, 0x19082b19, 0x08081908, 0x19082b2b, 0x08080808, 0x19190808, + 0x0808082b, 0x19190808, 0x08081919, 0x19190808, 0x08082b08, 0x19190808, 0x08190819, 0x19190808, + 0x08191908, 0x19190808, 0x082b0808, 0x19190808, 0x082b2b08, 0x19190808, 0x19080819, 0x19190808, + 0x19081908, 0x19190808, 0x19190808, 0x19190808, 0x2b080808, 0x19190808, 0x08080819, 0x19190819, + 0x08081908, 0x19190819, 0x08190808, 0x19190819, 0x08191919, 0x19190819, 0x19080808, 0x19190819, + 0x1908082b, 0x19190819, 0x08080808, 0x1919082b, 0x19081908, 0x1919082b, 0x2b2b2b2b, 0x1919082b, + 0x08080819, 0x19191908, 0x08081908, 0x19191908, 0x08190808, 0x19191908, 0x082b0819, 0x19191908, + 0x19080808, 0x19191908, 0x192b0808, 0x19191908, 0x2b080819, 0x19191908, 0x2b2b0819, 0x19191908, + 0x08080808, 0x19191919, 0x08082b08, 0x19191919, 0x2b080808, 0x19191919, 0x2b082b08, 0x19191919, + 0x082b0819, 0x1919192b, 0x192b2b08, 0x1919192b, 0x2b2b0819, 0x1919192b, 0x08080808, 0x19192b08, + 0x08191908, 0x19192b08, 0x19080819, 0x19192b08, 0x19190808, 0x19192b08, 0x2b192b19, 0x19192b08, + 0x08192b2b, 0x19192b19, 0x19080808, 0x19192b19, 0x1908082b, 0x19192b19, 0x2b081919, 0x19192b2b, + 0x08080819, 0x192b0808, 0x08081908, 0x192b0808, 0x08190808, 0x192b0808, 0x19080808, 0x192b0808, + 0x19191908, 0x192b0808, 0x192b082b, 0x192b0808, 0x2b08192b, 0x192b0808, 0x2b2b2b19, 0x192b0808, + 0x08080808, 0x192b0819, 0x082b1908, 0x192b082b, 0x19082b2b, 0x192b082b, 0x2b19082b, 0x192b082b, + 0x08080808, 0x192b1908, 0x0819192b, 0x192b1908, 0x08190808, 0x192b1919, 0x19080808, 0x192b1919, + 0x19081919, 0x192b1919, 0x2b2b1908, 0x192b1919, 0x08080819, 0x192b2b08, 0x192b2b2b, 0x192b2b08, + 0x082b1919, 0x192b2b19, 0x0808192b, 0x192b2b2b, 0x19191908, 0x192b2b2b, 0x192b082b, 0x192b2b2b, + 0x08080808, 0x2b080808, 0x0808082b, 0x2b080808, 0x08081919, 0x2b080808, 0x08082b08, 0x2b080808, + 0x08190819, 0x2b080808, 0x08191908, 0x2b080808, 0x082b0808, 0x2b080808, 0x082b2b2b, 0x2b080808, + 0x19080819, 0x2b080808, 0x19081908, 0x2b080808, 0x19190808, 0x2b080808, 0x2b080808, 0x2b080808, + 0x2b08082b, 0x2b080808, 0x2b2b2b08, 0x2b080808, 0x2b2b2b2b, 0x2b080808, 0x08080819, 0x2b080819, + 0x08081908, 0x2b080819, 0x0808192b, 0x2b080819, 0x08190808, 0x2b080819, 0x19080808, 0x2b080819, + 0x19190819, 0x2b080819, 0x19192b19, 0x2b080819, 0x08080808, 0x2b08082b, 0x082b0808, 0x2b08082b, + 0x2b080808, 0x2b08082b, 0x2b08082b, 0x2b08082b, 0x2b2b0808, 0x2b08082b, 0x2b2b2b08, 0x2b08082b, + 0x08080819, 0x2b081908, 0x08081908, 0x2b081908, 0x08190808, 0x2b081908, 0x0819082b, 0x2b081908, + 0x08191919, 0x2b081908, 0x19080808, 0x2b081908, 0x192b0808, 0x2b081908, 0x2b082b19, 0x2b081908, + 0x08080808, 0x2b081919, 0x19081908, 0x2b081919, 0x2b2b1919, 0x2b081919, 0x08192b08, 0x2b08192b, + 0x192b2b2b, 0x2b08192b, 0x08080808, 0x2b082b08, 0x08082b08, 0x2b082b08, 0x082b1919, 0x2b082b08, + 0x19192b2b, 0x2b082b08, 0x2b080808, 0x2b082b08, 0x2b08082b, 0x2b082b08, 0x2b2b2b08, 0x2b082b08, + 0x0808192b, 0x2b082b19, 0x082b082b, 0x2b082b2b, 0x2b080808, 0x2b082b2b, 0x2b082b08, 0x2b082b2b, + 0x2b19192b, 0x2b082b2b, 0x2b2b2b08, 0x2b082b2b, 0x08080819, 0x2b190808, 0x08081908, 0x2b190808, + 0x08190808, 0x2b190808, 0x19080808, 0x2b190808, 0x1919192b, 0x2b190808, 0x2b081908, 0x2b190808, + 0x08080808, 0x2b190819, 0x082b082b, 0x2b190819, 0x192b1908, 0x2b190819, 0x1919192b, 0x2b19082b, + 0x2b082b19, 0x2b19082b, 0x08080808, 0x2b191908, 0x08081919, 0x2b191908, 0x19081908, 0x2b191908, + 0x19190808, 0x2b191908, 0x19192b08, 0x2b191908, 0x082b2b19, 0x2b191919, 0x2b190808, 0x2b191919, + 0x2b19082b, 0x2b191919, 0x19080819, 0x2b19192b, 0x19190819, 0x2b192b08, 0x2b2b192b, 0x2b192b08, + 0x19082b19, 0x2b192b19, 0x08191919, 0x2b192b2b, 0x192b0808, 0x2b192b2b, 0x08080808, 0x2b2b0808, + 0x0808082b, 0x2b2b0808, 0x08082b08, 0x2b2b0808, 0x08082b2b, 0x2b2b0808, 0x082b0808, 0x2b2b0808, + 0x082b2b2b, 0x2b2b0808, 0x2b2b0808, 0x2b2b0808, 0x19190819, 0x2b2b0819, 0x19192b19, 0x2b2b0819, + 0x2b2b192b, 0x2b2b0819, 0x08080808, 0x2b2b082b, 0x0808082b, 0x2b2b082b, 0x08082b08, 0x2b2b082b, + 0x082b2b2b, 0x2b2b082b, 0x2b080808, 0x2b2b082b, 0x2b2b0808, 0x2b2b082b, 0x19080808, 0x2b2b1908, + 0x2b191919, 0x2b2b1908, 0x192b1919, 0x2b2b192b, 0x2b192b08, 0x2b2b192b, 0x08082b2b, 0x2b2b2b08, + 0x082b0808, 0x2b2b2b08, 0x082b082b, 0x2b2b2b08, 0x082b2b08, 0x2b2b2b08, 0x2b2b0808, 0x2b2b2b08, + 0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19, + 0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b +); +#enddecl(IQ2_XS_GRID) + +#decl(IQ2_S_GRID) +const iq2s_grid = array( + 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, + 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, + 0x08192b19, 0x08080808, 0x082b0808, 0x08080808, 0x082b082b, 0x08080808, 0x082b1919, 0x08080808, + 0x082b2b08, 0x08080808, 0x19080819, 0x08080808, 0x19081908, 0x08080808, 0x1908192b, 0x08080808, + 0x19082b19, 0x08080808, 0x19190808, 0x08080808, 0x1919082b, 0x08080808, 0x19191919, 0x08080808, + 0x19192b08, 0x08080808, 0x192b0819, 0x08080808, 0x192b1908, 0x08080808, 0x192b192b, 0x08080808, + 0x192b2b19, 0x08080808, 0x2b080808, 0x08080808, 0x2b08082b, 0x08080808, 0x2b081919, 0x08080808, + 0x2b082b08, 0x08080808, 0x2b190819, 0x08080808, 0x2b191908, 0x08080808, 0x2b2b0808, 0x08080808, + 0x2b2b1919, 0x08080808, 0x2b2b2b2b, 0x08080808, 0x08080819, 0x08080819, 0x08081908, 0x08080819, + 0x0808192b, 0x08080819, 0x08082b19, 0x08080819, 0x08190808, 0x08080819, 0x0819082b, 0x08080819, + 0x08191919, 0x08080819, 0x08192b08, 0x08080819, 0x082b0819, 0x08080819, 0x082b1908, 0x08080819, + 0x19080808, 0x08080819, 0x1908082b, 0x08080819, 0x19081919, 0x08080819, 0x19082b08, 0x08080819, + 0x19190819, 0x08080819, 0x19191908, 0x08080819, 0x1919192b, 0x08080819, 0x19192b19, 0x08080819, + 0x192b0808, 0x08080819, 0x192b1919, 0x08080819, 0x192b2b08, 0x08080819, 0x2b080819, 0x08080819, + 0x2b081908, 0x08080819, 0x2b190808, 0x08080819, 0x2b19082b, 0x08080819, 0x2b191919, 0x08080819, + 0x2b2b0819, 0x08080819, 0x2b2b1908, 0x08080819, 0x08080808, 0x0808082b, 0x0808082b, 0x0808082b, + 0x08081919, 0x0808082b, 0x08082b08, 0x0808082b, 0x08190819, 0x0808082b, 0x08191908, 0x0808082b, + 0x082b0808, 0x0808082b, 0x082b2b2b, 0x0808082b, 0x19080819, 0x0808082b, 0x19081908, 0x0808082b, + 0x1908192b, 0x0808082b, 0x19082b19, 0x0808082b, 0x19190808, 0x0808082b, 0x19191919, 0x0808082b, + 0x2b080808, 0x0808082b, 0x2b081919, 0x0808082b, 0x2b082b2b, 0x0808082b, 0x2b191908, 0x0808082b, + 0x2b2b082b, 0x0808082b, 0x08080819, 0x08081908, 0x08081908, 0x08081908, 0x0808192b, 0x08081908, + 0x08082b19, 0x08081908, 0x08190808, 0x08081908, 0x0819082b, 0x08081908, 0x08191919, 0x08081908, + 0x08192b08, 0x08081908, 0x082b0819, 0x08081908, 0x082b1908, 0x08081908, 0x082b192b, 0x08081908, + 0x082b2b19, 0x08081908, 0x19080808, 0x08081908, 0x1908082b, 0x08081908, 0x19081919, 0x08081908, + 0x19082b08, 0x08081908, 0x19082b2b, 0x08081908, 0x19190819, 0x08081908, 0x19191908, 0x08081908, + 0x1919192b, 0x08081908, 0x19192b19, 0x08081908, 0x192b0808, 0x08081908, 0x192b082b, 0x08081908, + 0x192b1919, 0x08081908, 0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b08192b, 0x08081908, + 0x2b082b19, 0x08081908, 0x2b190808, 0x08081908, 0x2b191919, 0x08081908, 0x2b192b08, 0x08081908, + 0x2b2b0819, 0x08081908, 0x2b2b1908, 0x08081908, 0x08080808, 0x08081919, 0x0808082b, 0x08081919, + 0x08081919, 0x08081919, 0x08082b08, 0x08081919, 0x08082b2b, 0x08081919, 0x08190819, 0x08081919, + 0x08191908, 0x08081919, 0x0819192b, 0x08081919, 0x08192b19, 0x08081919, 0x082b0808, 0x08081919, + 0x082b1919, 0x08081919, 0x082b2b08, 0x08081919, 0x19080819, 0x08081919, 0x19081908, 0x08081919, + 0x1908192b, 0x08081919, 0x19082b19, 0x08081919, 0x19190808, 0x08081919, 0x1919082b, 0x08081919, + 0x19191919, 0x08081919, 0x19192b08, 0x08081919, 0x192b0819, 0x08081919, 0x192b1908, 0x08081919, + 0x2b080808, 0x08081919, 0x2b08082b, 0x08081919, 0x2b081919, 0x08081919, 0x2b082b08, 0x08081919, + 0x2b190819, 0x08081919, 0x2b191908, 0x08081919, 0x2b2b0808, 0x08081919, 0x08080819, 0x0808192b, + 0x08081908, 0x0808192b, 0x0808192b, 0x0808192b, 0x08082b19, 0x0808192b, 0x08190808, 0x0808192b, + 0x08191919, 0x0808192b, 0x19080808, 0x0808192b, 0x19081919, 0x0808192b, 0x19082b08, 0x0808192b, + 0x19190819, 0x0808192b, 0x19191908, 0x0808192b, 0x192b0808, 0x0808192b, 0x2b080819, 0x0808192b, + 0x2b081908, 0x0808192b, 0x2b190808, 0x0808192b, 0x08080808, 0x08082b08, 0x0808082b, 0x08082b08, + 0x08081919, 0x08082b08, 0x08082b08, 0x08082b08, 0x08190819, 0x08082b08, 0x08191908, 0x08082b08, + 0x0819192b, 0x08082b08, 0x08192b19, 0x08082b08, 0x082b0808, 0x08082b08, 0x082b1919, 0x08082b08, + 0x082b2b2b, 0x08082b08, 0x19080819, 0x08082b08, 0x19081908, 0x08082b08, 0x1908192b, 0x08082b08, + 0x19082b19, 0x08082b08, 0x19190808, 0x08082b08, 0x1919082b, 0x08082b08, 0x19191919, 0x08082b08, + 0x19192b08, 0x08082b08, 0x192b0819, 0x08082b08, 0x192b1908, 0x08082b08, 0x2b080808, 0x08082b08, + 0x2b081919, 0x08082b08, 0x2b191908, 0x08082b08, 0x2b2b2b2b, 0x08082b08, 0x08080819, 0x08082b19, + 0x08081908, 0x08082b19, 0x08190808, 0x08082b19, 0x0819082b, 0x08082b19, 0x08191919, 0x08082b19, + 0x08192b08, 0x08082b19, 0x082b0819, 0x08082b19, 0x19080808, 0x08082b19, 0x19081919, 0x08082b19, + 0x19082b08, 0x08082b19, 0x19190819, 0x08082b19, 0x19191908, 0x08082b19, 0x192b0808, 0x08082b19, + 0x2b080819, 0x08082b19, 0x2b190808, 0x08082b19, 0x08080808, 0x08082b2b, 0x08190819, 0x08082b2b, + 0x08191908, 0x08082b2b, 0x082b082b, 0x08082b2b, 0x082b2b08, 0x08082b2b, 0x082b2b2b, 0x08082b2b, + 0x19190808, 0x08082b2b, 0x2b192b19, 0x08082b2b, 0x08080819, 0x08190808, 0x08081908, 0x08190808, + 0x0808192b, 0x08190808, 0x08082b19, 0x08190808, 0x08190808, 0x08190808, 0x0819082b, 0x08190808, + 0x08191919, 0x08190808, 0x08192b08, 0x08190808, 0x082b0819, 0x08190808, 0x082b1908, 0x08190808, + 0x082b192b, 0x08190808, 0x19080808, 0x08190808, 0x1908082b, 0x08190808, 0x19081919, 0x08190808, + 0x19082b08, 0x08190808, 0x19190819, 0x08190808, 0x19191908, 0x08190808, 0x1919192b, 0x08190808, + 0x19192b19, 0x08190808, 0x192b0808, 0x08190808, 0x192b082b, 0x08190808, 0x192b1919, 0x08190808, + 0x192b2b08, 0x08190808, 0x2b080819, 0x08190808, 0x2b081908, 0x08190808, 0x2b08192b, 0x08190808, + 0x2b190808, 0x08190808, 0x2b191919, 0x08190808, 0x2b192b08, 0x08190808, 0x2b2b0819, 0x08190808, + 0x2b2b1908, 0x08190808, 0x08080808, 0x08190819, 0x0808082b, 0x08190819, 0x08081919, 0x08190819, + 0x08082b08, 0x08190819, 0x08082b2b, 0x08190819, 0x08190819, 0x08190819, 0x08191908, 0x08190819, + 0x0819192b, 0x08190819, 0x08192b19, 0x08190819, 0x082b0808, 0x08190819, 0x082b082b, 0x08190819, + 0x082b1919, 0x08190819, 0x082b2b08, 0x08190819, 0x19080819, 0x08190819, 0x19081908, 0x08190819, + 0x1908192b, 0x08190819, 0x19082b19, 0x08190819, 0x19190808, 0x08190819, 0x1919082b, 0x08190819, + 0x19191919, 0x08190819, 0x19192b08, 0x08190819, 0x192b0819, 0x08190819, 0x192b1908, 0x08190819, + 0x2b080808, 0x08190819, 0x2b08082b, 0x08190819, 0x2b081919, 0x08190819, 0x2b082b08, 0x08190819, + 0x2b190819, 0x08190819, 0x2b191908, 0x08190819, 0x08080819, 0x0819082b, 0x08081908, 0x0819082b, + 0x08082b19, 0x0819082b, 0x08190808, 0x0819082b, 0x08191919, 0x0819082b, 0x082b0819, 0x0819082b, + 0x082b1908, 0x0819082b, 0x19080808, 0x0819082b, 0x19081919, 0x0819082b, 0x19190819, 0x0819082b, + 0x19191908, 0x0819082b, 0x2b080819, 0x0819082b, 0x2b081908, 0x0819082b, 0x2b190808, 0x0819082b, + 0x08080808, 0x08191908, 0x0808082b, 0x08191908, 0x08081919, 0x08191908, 0x08082b08, 0x08191908, + 0x08190819, 0x08191908, 0x08191908, 0x08191908, 0x0819192b, 0x08191908, 0x08192b19, 0x08191908, + 0x082b0808, 0x08191908, 0x082b1919, 0x08191908, 0x082b2b08, 0x08191908, 0x19080819, 0x08191908, + 0x19081908, 0x08191908, 0x1908192b, 0x08191908, 0x19082b19, 0x08191908, 0x19190808, 0x08191908, + 0x1919082b, 0x08191908, 0x19191919, 0x08191908, 0x19192b08, 0x08191908, 0x192b0819, 0x08191908, + 0x192b1908, 0x08191908, 0x2b080808, 0x08191908, 0x2b08082b, 0x08191908, 0x2b081919, 0x08191908, + 0x2b082b08, 0x08191908, 0x2b190819, 0x08191908, 0x2b191908, 0x08191908, 0x2b2b0808, 0x08191908, + 0x08080819, 0x08191919, 0x08081908, 0x08191919, 0x0808192b, 0x08191919, 0x08082b19, 0x08191919, + 0x08190808, 0x08191919, 0x0819082b, 0x08191919, 0x08191919, 0x08191919, 0x08192b08, 0x08191919, + 0x082b0819, 0x08191919, 0x082b1908, 0x08191919, 0x19080808, 0x08191919, 0x1908082b, 0x08191919, + 0x19081919, 0x08191919, 0x19082b08, 0x08191919, 0x19190819, 0x08191919, 0x19191908, 0x08191919, + 0x192b0808, 0x08191919, 0x2b080819, 0x08191919, 0x2b081908, 0x08191919, 0x2b190808, 0x08191919, + 0x08080808, 0x0819192b, 0x08081919, 0x0819192b, 0x08082b08, 0x0819192b, 0x08190819, 0x0819192b, + 0x08191908, 0x0819192b, 0x082b0808, 0x0819192b, 0x19080819, 0x0819192b, 0x19081908, 0x0819192b, + 0x19190808, 0x0819192b, 0x2b080808, 0x0819192b, 0x2b2b2b2b, 0x0819192b, 0x08080819, 0x08192b08, + 0x08081908, 0x08192b08, 0x0808192b, 0x08192b08, 0x08082b19, 0x08192b08, 0x08190808, 0x08192b08, + 0x08191919, 0x08192b08, 0x08192b08, 0x08192b08, 0x082b0819, 0x08192b08, 0x19080808, 0x08192b08, + 0x1908082b, 0x08192b08, 0x19081919, 0x08192b08, 0x19082b08, 0x08192b08, 0x19190819, 0x08192b08, + 0x19191908, 0x08192b08, 0x192b0808, 0x08192b08, 0x2b080819, 0x08192b08, 0x2b081908, 0x08192b08, + 0x08080808, 0x08192b19, 0x0808082b, 0x08192b19, 0x08081919, 0x08192b19, 0x08082b08, 0x08192b19, + 0x08190819, 0x08192b19, 0x08191908, 0x08192b19, 0x082b0808, 0x08192b19, 0x19080819, 0x08192b19, + 0x19081908, 0x08192b19, 0x19190808, 0x08192b19, 0x192b2b19, 0x08192b19, 0x2b2b082b, 0x08192b19, + 0x08081908, 0x08192b2b, 0x08190808, 0x08192b2b, 0x19080808, 0x08192b2b, 0x1919192b, 0x08192b2b, + 0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08081919, 0x082b0808, 0x08082b08, 0x082b0808, + 0x08190819, 0x082b0808, 0x08191908, 0x082b0808, 0x0819192b, 0x082b0808, 0x08192b19, 0x082b0808, + 0x082b0808, 0x082b0808, 0x082b1919, 0x082b0808, 0x082b2b2b, 0x082b0808, 0x19080819, 0x082b0808, + 0x19081908, 0x082b0808, 0x19190808, 0x082b0808, 0x1919082b, 0x082b0808, 0x19191919, 0x082b0808, + 0x192b1908, 0x082b0808, 0x2b080808, 0x082b0808, 0x2b082b2b, 0x082b0808, 0x2b191908, 0x082b0808, + 0x2b2b2b2b, 0x082b0808, 0x08080819, 0x082b0819, 0x08081908, 0x082b0819, 0x08190808, 0x082b0819, + 0x0819082b, 0x082b0819, 0x08191919, 0x082b0819, 0x082b0819, 0x082b0819, 0x19080808, 0x082b0819, + 0x1908082b, 0x082b0819, 0x19081919, 0x082b0819, 0x19190819, 0x082b0819, 0x19191908, 0x082b0819, + 0x192b0808, 0x082b0819, 0x2b080819, 0x082b0819, 0x2b081908, 0x082b0819, 0x2b190808, 0x082b0819, + 0x08080808, 0x082b082b, 0x08082b2b, 0x082b082b, 0x082b082b, 0x082b082b, 0x082b2b08, 0x082b082b, + 0x082b2b2b, 0x082b082b, 0x19081908, 0x082b082b, 0x19190808, 0x082b082b, 0x2b082b08, 0x082b082b, + 0x2b082b2b, 0x082b082b, 0x2b2b2b08, 0x082b082b, 0x08080819, 0x082b1908, 0x08081908, 0x082b1908, + 0x0808192b, 0x082b1908, 0x08082b19, 0x082b1908, 0x08190808, 0x082b1908, 0x08191919, 0x082b1908, + 0x08192b08, 0x082b1908, 0x082b0819, 0x082b1908, 0x082b1908, 0x082b1908, 0x19080808, 0x082b1908, + 0x1908082b, 0x082b1908, 0x19081919, 0x082b1908, 0x19082b08, 0x082b1908, 0x19190819, 0x082b1908, + 0x19191908, 0x082b1908, 0x192b0808, 0x082b1908, 0x2b080819, 0x082b1908, 0x2b081908, 0x082b1908, + 0x2b190808, 0x082b1908, 0x08080808, 0x082b1919, 0x08081919, 0x082b1919, 0x08082b08, 0x082b1919, + 0x08190819, 0x082b1919, 0x08191908, 0x082b1919, 0x082b0808, 0x082b1919, 0x19080819, 0x082b1919, + 0x19081908, 0x082b1919, 0x19190808, 0x082b1919, 0x192b192b, 0x082b1919, 0x2b080808, 0x082b1919, + 0x08080819, 0x082b192b, 0x08081908, 0x082b192b, 0x08190808, 0x082b192b, 0x19080808, 0x082b192b, + 0x19192b19, 0x082b192b, 0x08080808, 0x082b2b08, 0x08081919, 0x082b2b08, 0x08190819, 0x082b2b08, + 0x08191908, 0x082b2b08, 0x19080819, 0x082b2b08, 0x19081908, 0x082b2b08, 0x19190808, 0x082b2b08, + 0x2b082b2b, 0x082b2b08, 0x2b2b2b2b, 0x082b2b08, 0x08080819, 0x082b2b19, 0x08081908, 0x082b2b19, + 0x08190808, 0x082b2b19, 0x2b191919, 0x082b2b19, 0x08082b2b, 0x082b2b2b, 0x082b082b, 0x082b2b2b, + 0x192b1908, 0x082b2b2b, 0x2b082b08, 0x082b2b2b, 0x2b082b2b, 0x082b2b2b, 0x08080819, 0x19080808, + 0x08081908, 0x19080808, 0x0808192b, 0x19080808, 0x08082b19, 0x19080808, 0x08190808, 0x19080808, + 0x0819082b, 0x19080808, 0x08191919, 0x19080808, 0x08192b08, 0x19080808, 0x08192b2b, 0x19080808, + 0x082b0819, 0x19080808, 0x082b1908, 0x19080808, 0x082b192b, 0x19080808, 0x19080808, 0x19080808, + 0x1908082b, 0x19080808, 0x19081919, 0x19080808, 0x19082b08, 0x19080808, 0x19082b2b, 0x19080808, + 0x19190819, 0x19080808, 0x19191908, 0x19080808, 0x1919192b, 0x19080808, 0x19192b19, 0x19080808, + 0x192b0808, 0x19080808, 0x192b082b, 0x19080808, 0x192b1919, 0x19080808, 0x2b080819, 0x19080808, + 0x2b081908, 0x19080808, 0x2b190808, 0x19080808, 0x2b191919, 0x19080808, 0x2b192b08, 0x19080808, + 0x2b2b0819, 0x19080808, 0x2b2b1908, 0x19080808, 0x08080808, 0x19080819, 0x0808082b, 0x19080819, + 0x08081919, 0x19080819, 0x08082b08, 0x19080819, 0x08190819, 0x19080819, 0x08191908, 0x19080819, + 0x0819192b, 0x19080819, 0x08192b19, 0x19080819, 0x082b0808, 0x19080819, 0x082b082b, 0x19080819, + 0x082b1919, 0x19080819, 0x19080819, 0x19080819, 0x19081908, 0x19080819, 0x1908192b, 0x19080819, + 0x19082b19, 0x19080819, 0x19190808, 0x19080819, 0x1919082b, 0x19080819, 0x19191919, 0x19080819, + 0x19192b08, 0x19080819, 0x192b0819, 0x19080819, 0x192b1908, 0x19080819, 0x2b080808, 0x19080819, + 0x2b08082b, 0x19080819, 0x2b081919, 0x19080819, 0x2b082b08, 0x19080819, 0x2b190819, 0x19080819, + 0x2b191908, 0x19080819, 0x2b2b0808, 0x19080819, 0x08080819, 0x1908082b, 0x08081908, 0x1908082b, + 0x08190808, 0x1908082b, 0x0819082b, 0x1908082b, 0x08191919, 0x1908082b, 0x08192b08, 0x1908082b, + 0x082b1908, 0x1908082b, 0x19080808, 0x1908082b, 0x19081919, 0x1908082b, 0x19082b08, 0x1908082b, + 0x19190819, 0x1908082b, 0x19191908, 0x1908082b, 0x192b0808, 0x1908082b, 0x2b080819, 0x1908082b, + 0x2b081908, 0x1908082b, 0x08080808, 0x19081908, 0x0808082b, 0x19081908, 0x08081919, 0x19081908, + 0x08082b08, 0x19081908, 0x08082b2b, 0x19081908, 0x08190819, 0x19081908, 0x08191908, 0x19081908, + 0x0819192b, 0x19081908, 0x08192b19, 0x19081908, 0x082b0808, 0x19081908, 0x082b082b, 0x19081908, + 0x082b1919, 0x19081908, 0x082b2b08, 0x19081908, 0x19080819, 0x19081908, 0x19081908, 0x19081908, + 0x1908192b, 0x19081908, 0x19082b19, 0x19081908, 0x19190808, 0x19081908, 0x1919082b, 0x19081908, + 0x19191919, 0x19081908, 0x19192b08, 0x19081908, 0x192b0819, 0x19081908, 0x192b1908, 0x19081908, + 0x2b080808, 0x19081908, 0x2b08082b, 0x19081908, 0x2b081919, 0x19081908, 0x2b082b08, 0x19081908, + 0x2b190819, 0x19081908, 0x2b191908, 0x19081908, 0x2b2b0808, 0x19081908, 0x08080819, 0x19081919, + 0x08081908, 0x19081919, 0x0808192b, 0x19081919, 0x08082b19, 0x19081919, 0x08190808, 0x19081919, + 0x0819082b, 0x19081919, 0x08191919, 0x19081919, 0x08192b08, 0x19081919, 0x082b0819, 0x19081919, + 0x082b1908, 0x19081919, 0x19080808, 0x19081919, 0x1908082b, 0x19081919, 0x19081919, 0x19081919, + 0x19082b08, 0x19081919, 0x19190819, 0x19081919, 0x19191908, 0x19081919, 0x192b0808, 0x19081919, + 0x192b2b2b, 0x19081919, 0x2b080819, 0x19081919, 0x2b081908, 0x19081919, 0x2b190808, 0x19081919, + 0x08080808, 0x1908192b, 0x0808082b, 0x1908192b, 0x08081919, 0x1908192b, 0x08082b08, 0x1908192b, + 0x08190819, 0x1908192b, 0x08191908, 0x1908192b, 0x082b0808, 0x1908192b, 0x19080819, 0x1908192b, + 0x19081908, 0x1908192b, 0x19190808, 0x1908192b, 0x2b080808, 0x1908192b, 0x2b2b1919, 0x1908192b, + 0x08080819, 0x19082b08, 0x08081908, 0x19082b08, 0x08082b19, 0x19082b08, 0x08190808, 0x19082b08, + 0x0819082b, 0x19082b08, 0x08191919, 0x19082b08, 0x08192b08, 0x19082b08, 0x082b0819, 0x19082b08, + 0x082b1908, 0x19082b08, 0x19080808, 0x19082b08, 0x1908082b, 0x19082b08, 0x19081919, 0x19082b08, + 0x19082b08, 0x19082b08, 0x19190819, 0x19082b08, 0x19191908, 0x19082b08, 0x192b0808, 0x19082b08, + 0x2b081908, 0x19082b08, 0x2b190808, 0x19082b08, 0x08080808, 0x19082b19, 0x0808082b, 0x19082b19, + 0x08081919, 0x19082b19, 0x08082b08, 0x19082b19, 0x08190819, 0x19082b19, 0x08191908, 0x19082b19, + 0x082b0808, 0x19082b19, 0x19080819, 0x19082b19, 0x19081908, 0x19082b19, 0x19190808, 0x19082b19, + 0x2b080808, 0x19082b19, 0x2b19192b, 0x19082b19, 0x08080819, 0x19082b2b, 0x08081908, 0x19082b2b, + 0x08190808, 0x19082b2b, 0x19080808, 0x19082b2b, 0x08080808, 0x19190808, 0x0808082b, 0x19190808, + 0x08081919, 0x19190808, 0x08082b08, 0x19190808, 0x08190819, 0x19190808, 0x08191908, 0x19190808, + 0x0819192b, 0x19190808, 0x08192b19, 0x19190808, 0x082b0808, 0x19190808, 0x082b082b, 0x19190808, + 0x082b1919, 0x19190808, 0x082b2b08, 0x19190808, 0x19080819, 0x19190808, 0x19081908, 0x19190808, + 0x1908192b, 0x19190808, 0x19082b19, 0x19190808, 0x19190808, 0x19190808, 0x1919082b, 0x19190808, + 0x19191919, 0x19190808, 0x19192b08, 0x19190808, 0x192b0819, 0x19190808, 0x192b1908, 0x19190808, + 0x2b080808, 0x19190808, 0x2b08082b, 0x19190808, 0x2b081919, 0x19190808, 0x2b082b08, 0x19190808, + 0x2b190819, 0x19190808, 0x2b191908, 0x19190808, 0x08080819, 0x19190819, 0x08081908, 0x19190819, + 0x0808192b, 0x19190819, 0x08082b19, 0x19190819, 0x08190808, 0x19190819, 0x0819082b, 0x19190819, + 0x08191919, 0x19190819, 0x08192b08, 0x19190819, 0x082b0819, 0x19190819, 0x082b1908, 0x19190819, + 0x19080808, 0x19190819, 0x1908082b, 0x19190819, 0x19081919, 0x19190819, 0x19082b08, 0x19190819, + 0x19190819, 0x19190819, 0x19191908, 0x19190819, 0x192b0808, 0x19190819, 0x2b080819, 0x19190819, + 0x2b081908, 0x19190819, 0x2b190808, 0x19190819, 0x08080808, 0x1919082b, 0x08081919, 0x1919082b, + 0x08082b08, 0x1919082b, 0x08190819, 0x1919082b, 0x08191908, 0x1919082b, 0x082b0808, 0x1919082b, + 0x19080819, 0x1919082b, 0x19081908, 0x1919082b, 0x19190808, 0x1919082b, 0x192b2b19, 0x1919082b, + 0x2b080808, 0x1919082b, 0x08080819, 0x19191908, 0x08081908, 0x19191908, 0x0808192b, 0x19191908, + 0x08082b19, 0x19191908, 0x08190808, 0x19191908, 0x0819082b, 0x19191908, 0x08191919, 0x19191908, + 0x08192b08, 0x19191908, 0x082b0819, 0x19191908, 0x082b1908, 0x19191908, 0x19080808, 0x19191908, + 0x1908082b, 0x19191908, 0x19081919, 0x19191908, 0x19082b08, 0x19191908, 0x19190819, 0x19191908, + 0x19191908, 0x19191908, 0x192b0808, 0x19191908, 0x2b080819, 0x19191908, 0x2b081908, 0x19191908, + 0x2b190808, 0x19191908, 0x08080808, 0x19191919, 0x0808082b, 0x19191919, 0x08081919, 0x19191919, + 0x08082b08, 0x19191919, 0x08190819, 0x19191919, 0x08191908, 0x19191919, 0x082b0808, 0x19191919, + 0x19080819, 0x19191919, 0x19081908, 0x19191919, 0x19190808, 0x19191919, 0x2b080808, 0x19191919, + 0x08080819, 0x1919192b, 0x08081908, 0x1919192b, 0x08190808, 0x1919192b, 0x082b192b, 0x1919192b, + 0x19080808, 0x1919192b, 0x08080808, 0x19192b08, 0x0808082b, 0x19192b08, 0x08081919, 0x19192b08, + 0x08082b08, 0x19192b08, 0x08190819, 0x19192b08, 0x08191908, 0x19192b08, 0x082b0808, 0x19192b08, + 0x19080819, 0x19192b08, 0x19081908, 0x19192b08, 0x19190808, 0x19192b08, 0x19192b2b, 0x19192b08, + 0x2b080808, 0x19192b08, 0x08080819, 0x19192b19, 0x08081908, 0x19192b19, 0x08190808, 0x19192b19, + 0x19080808, 0x19192b19, 0x08080808, 0x19192b2b, 0x08192b19, 0x19192b2b, 0x2b081919, 0x19192b2b, + 0x2b2b2b08, 0x19192b2b, 0x08080819, 0x192b0808, 0x08081908, 0x192b0808, 0x0808192b, 0x192b0808, + 0x08190808, 0x192b0808, 0x0819082b, 0x192b0808, 0x08191919, 0x192b0808, 0x08192b08, 0x192b0808, + 0x082b0819, 0x192b0808, 0x082b1908, 0x192b0808, 0x19080808, 0x192b0808, 0x19081919, 0x192b0808, + 0x19082b08, 0x192b0808, 0x19190819, 0x192b0808, 0x19191908, 0x192b0808, 0x192b0808, 0x192b0808, + 0x2b081908, 0x192b0808, 0x2b190808, 0x192b0808, 0x08080808, 0x192b0819, 0x0808082b, 0x192b0819, + 0x08081919, 0x192b0819, 0x08082b08, 0x192b0819, 0x08190819, 0x192b0819, 0x08191908, 0x192b0819, + 0x082b0808, 0x192b0819, 0x19080819, 0x192b0819, 0x19081908, 0x192b0819, 0x19190808, 0x192b0819, + 0x2b080808, 0x192b0819, 0x2b192b19, 0x192b0819, 0x08081908, 0x192b082b, 0x08190808, 0x192b082b, + 0x19080808, 0x192b082b, 0x1919192b, 0x192b082b, 0x2b2b0819, 0x192b082b, 0x08080808, 0x192b1908, + 0x08081919, 0x192b1908, 0x08082b08, 0x192b1908, 0x08190819, 0x192b1908, 0x08191908, 0x192b1908, + 0x082b0808, 0x192b1908, 0x19080819, 0x192b1908, 0x19081908, 0x192b1908, 0x19190808, 0x192b1908, + 0x2b080808, 0x192b1908, 0x08080819, 0x192b1919, 0x08081908, 0x192b1919, 0x08190808, 0x192b1919, + 0x19080808, 0x192b1919, 0x19082b2b, 0x192b1919, 0x192b2b08, 0x192b1919, 0x2b19082b, 0x192b1919, + 0x08080808, 0x192b192b, 0x2b191908, 0x192b192b, 0x08080819, 0x192b2b08, 0x08081908, 0x192b2b08, + 0x08190808, 0x192b2b08, 0x192b1919, 0x192b2b08, 0x2b192b08, 0x192b2b08, 0x08080808, 0x192b2b19, + 0x082b2b2b, 0x192b2b19, 0x1908082b, 0x192b2b2b, 0x2b2b0819, 0x192b2b2b, 0x08080808, 0x2b080808, + 0x0808082b, 0x2b080808, 0x08081919, 0x2b080808, 0x08082b08, 0x2b080808, 0x08190819, 0x2b080808, + 0x08191908, 0x2b080808, 0x08192b19, 0x2b080808, 0x082b0808, 0x2b080808, 0x082b1919, 0x2b080808, + 0x19080819, 0x2b080808, 0x19081908, 0x2b080808, 0x19190808, 0x2b080808, 0x1919082b, 0x2b080808, + 0x19191919, 0x2b080808, 0x19192b08, 0x2b080808, 0x192b0819, 0x2b080808, 0x2b080808, 0x2b080808, + 0x2b081919, 0x2b080808, 0x2b190819, 0x2b080808, 0x2b191908, 0x2b080808, 0x08080819, 0x2b080819, + 0x08081908, 0x2b080819, 0x08082b19, 0x2b080819, 0x08190808, 0x2b080819, 0x0819082b, 0x2b080819, + 0x08191919, 0x2b080819, 0x08192b08, 0x2b080819, 0x082b0819, 0x2b080819, 0x082b1908, 0x2b080819, + 0x19080808, 0x2b080819, 0x1908082b, 0x2b080819, 0x19081919, 0x2b080819, 0x19082b08, 0x2b080819, + 0x19190819, 0x2b080819, 0x19191908, 0x2b080819, 0x2b080819, 0x2b080819, 0x2b081908, 0x2b080819, + 0x2b190808, 0x2b080819, 0x2b2b2b19, 0x2b080819, 0x08080808, 0x2b08082b, 0x08081919, 0x2b08082b, + 0x08082b2b, 0x2b08082b, 0x08190819, 0x2b08082b, 0x08191908, 0x2b08082b, 0x19080819, 0x2b08082b, + 0x19081908, 0x2b08082b, 0x19190808, 0x2b08082b, 0x08080819, 0x2b081908, 0x08081908, 0x2b081908, + 0x0808192b, 0x2b081908, 0x08082b19, 0x2b081908, 0x08190808, 0x2b081908, 0x0819082b, 0x2b081908, + 0x08191919, 0x2b081908, 0x08192b08, 0x2b081908, 0x082b0819, 0x2b081908, 0x19080808, 0x2b081908, + 0x1908082b, 0x2b081908, 0x19081919, 0x2b081908, 0x19082b08, 0x2b081908, 0x19190819, 0x2b081908, + 0x19191908, 0x2b081908, 0x192b0808, 0x2b081908, 0x2b080819, 0x2b081908, 0x2b081908, 0x2b081908, + 0x2b190808, 0x2b081908, 0x08080808, 0x2b081919, 0x0808082b, 0x2b081919, 0x08081919, 0x2b081919, + 0x08082b08, 0x2b081919, 0x08190819, 0x2b081919, 0x08191908, 0x2b081919, 0x082b0808, 0x2b081919, + 0x19080819, 0x2b081919, 0x19081908, 0x2b081919, 0x19190808, 0x2b081919, 0x2b080808, 0x2b081919, + 0x2b082b2b, 0x2b081919, 0x08080819, 0x2b08192b, 0x08081908, 0x2b08192b, 0x08190808, 0x2b08192b, + 0x082b2b19, 0x2b08192b, 0x19080808, 0x2b08192b, 0x08080808, 0x2b082b08, 0x08081919, 0x2b082b08, + 0x08190819, 0x2b082b08, 0x08191908, 0x2b082b08, 0x19080819, 0x2b082b08, 0x19081908, 0x2b082b08, + 0x19190808, 0x2b082b08, 0x2b2b082b, 0x2b082b08, 0x08080819, 0x2b082b19, 0x08081908, 0x2b082b19, + 0x19080808, 0x2b082b19, 0x192b1919, 0x2b082b19, 0x082b082b, 0x2b082b2b, 0x19192b08, 0x2b082b2b, + 0x19192b2b, 0x2b082b2b, 0x2b08082b, 0x2b082b2b, 0x2b2b082b, 0x2b082b2b, 0x08080819, 0x2b190808, + 0x08081908, 0x2b190808, 0x08082b19, 0x2b190808, 0x08190808, 0x2b190808, 0x0819082b, 0x2b190808, + 0x08191919, 0x2b190808, 0x08192b08, 0x2b190808, 0x082b1908, 0x2b190808, 0x19080808, 0x2b190808, + 0x1908082b, 0x2b190808, 0x19081919, 0x2b190808, 0x19082b08, 0x2b190808, 0x19190819, 0x2b190808, + 0x19191908, 0x2b190808, 0x192b0808, 0x2b190808, 0x2b080819, 0x2b190808, 0x2b081908, 0x2b190808, + 0x2b190808, 0x2b190808, 0x08080808, 0x2b190819, 0x08081919, 0x2b190819, 0x08190819, 0x2b190819, + 0x08191908, 0x2b190819, 0x19080819, 0x2b190819, 0x19081908, 0x2b190819, 0x19190808, 0x2b190819, + 0x19192b2b, 0x2b190819, 0x08080819, 0x2b19082b, 0x08081908, 0x2b19082b, 0x08190808, 0x2b19082b, + 0x19080808, 0x2b19082b, 0x2b2b192b, 0x2b19082b, 0x08080808, 0x2b191908, 0x0808082b, 0x2b191908, + 0x08081919, 0x2b191908, 0x08082b08, 0x2b191908, 0x08190819, 0x2b191908, 0x08191908, 0x2b191908, + 0x082b0808, 0x2b191908, 0x19080819, 0x2b191908, 0x19081908, 0x2b191908, 0x19190808, 0x2b191908, + 0x2b080808, 0x2b191908, 0x2b19192b, 0x2b191908, 0x08080819, 0x2b191919, 0x08081908, 0x2b191919, + 0x08190808, 0x2b191919, 0x19080808, 0x2b191919, 0x2b192b08, 0x2b191919, 0x2b2b0819, 0x2b191919, + 0x08080808, 0x2b19192b, 0x1908192b, 0x2b19192b, 0x192b1908, 0x2b19192b, 0x08080819, 0x2b192b08, + 0x08081908, 0x2b192b08, 0x08190808, 0x2b192b08, 0x082b192b, 0x2b192b08, 0x19080808, 0x2b192b08, + 0x2b2b2b19, 0x2b192b08, 0x08080808, 0x2b192b19, 0x19082b19, 0x2b192b19, 0x1919082b, 0x2b192b19, + 0x2b190808, 0x2b192b2b, 0x08080808, 0x2b2b0808, 0x08081919, 0x2b2b0808, 0x08082b2b, 0x2b2b0808, + 0x08191908, 0x2b2b0808, 0x082b082b, 0x2b2b0808, 0x082b2b2b, 0x2b2b0808, 0x19080819, 0x2b2b0808, + 0x19081908, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b2b082b, 0x2b2b0808, 0x2b2b2b2b, 0x2b2b0808, + 0x19080808, 0x2b2b0819, 0x192b1919, 0x2b2b0819, 0x0808082b, 0x2b2b082b, 0x08082b2b, 0x2b2b082b, + 0x082b082b, 0x2b2b082b, 0x082b2b08, 0x2b2b082b, 0x082b2b2b, 0x2b2b082b, 0x2b08082b, 0x2b2b082b, + 0x2b082b08, 0x2b2b082b, 0x2b082b2b, 0x2b2b082b, 0x2b2b2b08, 0x2b2b082b, 0x08080819, 0x2b2b1908, + 0x08081908, 0x2b2b1908, 0x08190808, 0x2b2b1908, 0x19080808, 0x2b2b1908, 0x2b082b19, 0x2b2b1908, + 0x2b2b1908, 0x2b2b1908, 0x08080808, 0x2b2b1919, 0x08192b19, 0x2b2b1919, 0x19190819, 0x2b2b192b, + 0x08082b2b, 0x2b2b2b08, 0x082b2b08, 0x2b2b2b08, 0x2b2b082b, 0x2b2b2b08, 0x19191908, 0x2b2b2b19, + 0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b, + 0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b +); +#enddecl(IQ2_S_GRID) + +#decl(IQ3_XSS_GRID) + +const iq3xxs_grid = array( + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04 +); +#enddecl(IQ3_XSS_GRID) + +#decl(IQ3_S_GRID) + +const iq3s_grid = array( + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101 +); +#enddecl(IQ3_S_GRID) + +#decl(IQ1_GRID) + +const IQ1_DELTA: f32 = 0.125; + +const iq1_grid = array( + 0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01, + 0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4, + 0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41, + 0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f, + 0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 0xf33cf3d0, 0xf300f334, + 0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f, + 0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040, + 0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f, + 0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5, + 0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3, + 0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff, + 0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570, + 0xf540f575, 0xf55df55f, 0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f, + 0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf, + 0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f, + 0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07, + 0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc, + 0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374, + 0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0, + 0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001, + 0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043, + 0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc, + 0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117, + 0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f, + 0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 0xc743c74c, 0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5, + 0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474, + 0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d, + 0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd, + 0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50, + 0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10, + 0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30, + 0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1, + 0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c, + 0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074, + 0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134, + 0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 0xd7d5d7d7, + 0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3, + 0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450, + 0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577, + 0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c, + 0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5, + 0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c, + 0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00, + 0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300, + 0x331d331c, 0x33173310, 0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc, + 0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034, + 0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077, + 0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5, + 0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117, + 0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f, + 0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5, + 0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 0x34013400, 0x341f3404, + 0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1, + 0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd, + 0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71, + 0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7, + 0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00, + 0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44, + 0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00, + 0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0, + 0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 0x03000303, + 0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343, + 0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd, + 0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031, + 0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011, + 0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c, + 0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4, + 0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c, + 0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174, + 0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7, + 0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d, + 0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4, + 0x0433043c, 0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c, + 0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7, + 0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510, + 0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33, + 0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4, + 0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73, + 0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f, + 0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337, + 0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343, + 0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030, + 0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075, + 0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 0x11d711d1, 0x113f11d4, + 0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170, + 0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705, + 0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c, + 0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c, + 0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514, + 0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c, + 0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3, + 0x7cd07cc4, 0x7c337c3c, 0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70, + 0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03, + 0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c, + 0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c, + 0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074, + 0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104, + 0x7111711c, 0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7, + 0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757, + 0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c, + 0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c, + 0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4, + 0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc, + 0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03, + 0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc, + 0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54, + 0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f, + 0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf, + 0x40c040c3, 0x40c440c1, 0x40d040dc, 0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c, + 0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c, + 0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4, + 0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174, + 0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700, + 0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7, + 0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d, + 0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531, + 0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf, + 0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57, + 0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13, + 0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01, + 0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f, + 0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7, + 0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074, + 0x50475040, 0x51cc51f0, 0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107, + 0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd, + 0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0, + 0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7, + 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 +); + +#enddecl(IQ1_GRID) + +#decl(IQ4_GRID) + +const kvalues_iq4nl = array( + -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 +); + +#enddecl(IQ4_GRID) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 962dcd6b170..d9dfd7d6f4f 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -1,35 +1,124 @@ import os +import re +import ast import argparse -def escape_triple_quotes(wgsl): - # Simple defense in case of embedded """ - return wgsl.replace('"""', '\\"""') +def extract_block(text, name): + pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)' + match = re.search(pattern, text, re.DOTALL) + if not match: + raise ValueError(f"Missing block: {name}") + return match.group(1).strip() -def to_cpp_string_literal(varname, content): - return f'const char* wgsl_{varname} = R"({content})";\n' +def parse_decls(decls_text): + decls = {} + for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL): + decls[name.strip()] = code.strip() + return decls + + +def replace_placeholders(shader_text, replacements): + for key, val in replacements.items(): + # Match {{KEY}} literally, where KEY is escaped + pattern = r'{{\s*' + re.escape(key) + r'\s*}}' + shader_text = re.sub(pattern, str(val), shader_text) + return shader_text + + +def expand_includes(shader, input_dir): + """ + Replace #include "file" lines in the text with the contents of that file. + Searches for files relative to input_dir. + """ + include_pattern = re.compile(r'^\s*#include\s+"([^"]+)"\s*$', re.MULTILINE) + + def replacer(match): + fname = match.group(1) + file_path = os.path.join(input_dir, fname) + if not os.path.exists(file_path): + raise FileNotFoundError(f"Included file not found: {file_path}") + with open(file_path, "r", encoding="utf-8") as f: + included_code = f.read() + # Recursively expand includes inside the included file + return expand_includes(included_code, input_dir) + + return include_pattern.sub(replacer, shader) + + +def write_shader(shader_name, shader_code, output_dir, outfile): + if output_dir: + wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl") + with open(wgsl_filename, "w", encoding="utf-8") as f_out: + f_out.write(shader_code) + outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n') + + +def generate_variants(fname, input_dir, output_dir, outfile): + shader_path = os.path.join(input_dir, fname) + shader_base_name = fname.split(".")[0] + + with open(shader_path, "r", encoding="utf-8") as f: + text = f.read() + + try: + variants = ast.literal_eval(extract_block(text, "VARIANTS")) + except ValueError: + write_shader(shader_base_name, text, output_dir, outfile) + else: + try: + decls_map = parse_decls(extract_block(text, "DECLS")) + except ValueError: + decls_map = {} + + with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f: + common_decls = f.read() + decls_map.update(parse_decls(common_decls)) + + shader_template = extract_block(text, "SHADER") + for variant in variants: + if "DECLS" in variant: + decls = variant["DECLS"] + else: + decls = [] + decls_code = "" + for key in decls: + if key not in decls_map: + raise ValueError(f"DECLS key '{key}' not found.") + decls_code += decls_map[key] + "\n\n" + + shader_variant = replace_placeholders(shader_template, variant["REPLS"]) + final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant) + final_shader = expand_includes(final_shader, input_dir) + + if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]: + output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]]) + elif "TYPE_SUFFIX" in variant["REPLS"]: + output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE_SUFFIX"] + elif "TYPE" in variant["REPLS"]: + output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] + else: + output_name = shader_base_name + write_shader(output_name, final_shader, output_dir, outfile) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--input', required=True) - parser.add_argument('--output', required=True) + parser.add_argument("--input_dir", required=True) + parser.add_argument("--output_file", required=True) + parser.add_argument("--output_dir") args = parser.parse_args() - with open(args.output, 'w', encoding='utf-8') as out: - out.write("// Auto-generated shader embedding \n\n") - for fname in sorted(os.listdir(args.input)): - if not fname.endswith('.wgsl'): - continue - shader_path = os.path.join(args.input, fname) - varname = os.path.splitext(fname)[0] - with open(shader_path, 'r', encoding='utf-8') as f: - content = f.read() - content = escape_triple_quotes(content) - out.write(to_cpp_string_literal(varname, content)) - out.write('\n') - - -if __name__ == '__main__': + if args.output_dir: + os.makedirs(args.output_dir, exist_ok=True) + + with open(args.output_file, "w", encoding="utf-8") as out: + out.write("// Auto-generated shader embedding\n\n") + for fname in sorted(os.listdir(args.input_dir)): + if fname.endswith(".wgsl"): + generate_variants(fname, args.input_dir, args.output_dir, out) + + +if __name__ == "__main__": main() diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl new file mode 100644 index 00000000000..e3fe311b26b --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl @@ -0,0 +1,874 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "TYPE" : "vec4", + "TYPE_SUFFIX": "f32_vec", + "DST_TYPE": "vec4", + "BLOCK_SIZE": 4 + }, + "DECLS": ["F32_VEC"] + }, + { + "REPLS": { + "TYPE" : "f32", + "DST_TYPE": "f32", + "BLOCK_SIZE": 1 + }, + "DECLS": ["F32"] + }, + { + "REPLS": { + "TYPE" : "f16", + "DST_TYPE": "f32", + "BLOCK_SIZE": 1 + }, + "DECLS": ["F16"] + }, + { + "REPLS": { + "TYPE" : "i32", + "DST_TYPE": "i32", + "BLOCK_SIZE": 1 + }, + "DECLS": ["I32"] + }, + { + "REPLS": { + "TYPE" : "q4_0", + "DST_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"] + }, + { + "REPLS": { + "TYPE" : "q4_1", + "DST_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"] + }, + { + "REPLS": { + "TYPE" : "q5_0", + "DST_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"] + }, + { + "REPLS": { + "TYPE" : "q5_1", + "DST_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"] + }, + { + "REPLS": { + "TYPE" : "q8_0", + "DST_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"] + }, + { + "REPLS": { + "TYPE" : "q2_k", + "DST_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"] + }, + { + "REPLS": { + "TYPE" : "q3_k", + "DST_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"] + }, + { + "REPLS": { + "TYPE" : "q4_k", + "DST_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"] + }, + { + "REPLS": { + "TYPE" : "q5_k", + "DST_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"] + }, + { + "REPLS": { + "TYPE" : "q6_k", + "DST_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"] + }, + { + "REPLS": { + "TYPE" : "iq2_xxs", + "DST_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"] + }, + { + "REPLS": { + "TYPE" : "iq2_xs", + "DST_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"] + }, + { + "REPLS": { + "TYPE": "iq2_s", + "DST_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"] + }, + { + "REPLS": { + "TYPE": "iq3_xxs", + "DST_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"] + }, + { + "REPLS": { + "TYPE": "iq3_s", + "DST_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"] + }, + { + "REPLS": { + "TYPE": "iq1_s", + "DST_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"] + }, + { + "REPLS": { + "TYPE": "iq1_m", + "DST_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"] + }, + { + "REPLS": { + "TYPE": "iq4_nl", + "DST_TYPE": "f32", + "BLOCK_SIZE": 32, + }, + "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"] + }, + { + "REPLS": { + "TYPE": "iq4_xs", + "DST_TYPE": "f32", + "BLOCK_SIZE": 256, + }, + "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(F32_VEC) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset]; +} +#enddecl(F32_VEC) + +#decl(F32) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + dst[dst_base + offset] = src[src_base + offset]; +} +#enddecl(F32) + +#decl(F16) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + dst[dst_base + offset] = f32(src[src_base + offset]); +} +#enddecl(F16) + +#decl(I32) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + dst[dst_base + offset] = src[src_base + offset]; +} +#enddecl(I32) + +#decl(Q4_0) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_q4_0 = src[src_base + offset]; + let d = f32(block_q4_0.d); + for (var j: u32 = 0; j < 4; j++) { + let q_packed = bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; + let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; + let dst_offset = dst_base + offset * 32 + j * 4 + k; + dst[dst_offset] = q_lo; + dst[dst_offset + 16] = q_hi; + } + } +} +#enddecl(Q4_0) + +#decl(Q4_1) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_q4_1 = src[src_base + offset]; + let d = f32(block_q4_1.d); + let m = f32(block_q4_1.m); + for (var j: u32 = 0; j < 4; j++) { + let q_packed = block_q4_1.qs[j]; + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32((q_byte >> 4) & 0xF) * d + m; + let q_lo = f32(q_byte & 0xF) * d + m; + let dst_offset = dst_base + offset * 32 + j * 4 + k; + dst[dst_offset] = q_lo; + dst[dst_offset + 16] = q_hi; + } + } +} +#enddecl(Q4_1) + +#decl(Q5_0) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_q5_0 = src[src_base + offset]; + let d = f32(block_q5_0.d); + let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); + for (var j: u32 = 0; j < 4; j++) { + let q_packed = bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; + let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; + let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; + let dst_offset = dst_base + offset * 32 + j * 4 + k; + dst[dst_offset] = q_lo; + dst[dst_offset + 16] = q_hi; + } + } +} + +#enddecl(Q5_0) + +#decl(Q5_1) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_q5_1 = src[src_base + offset]; + let d = f32(block_q5_1.d); + let m = f32(block_q5_1.m); + for (var j: u32 = 0; j < 4; j++) { + let q_packed = block_q5_1.qs[j]; + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10; + let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m; + let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10; + let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m; + let dst_offset = dst_base + offset * 32 + j * 4 + k; + dst[dst_offset] = q_lo; + dst[dst_offset + 16] = q_hi; + } + } +} +#enddecl(Q5_1) + +#decl(Q8_0) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_q8_0 = src[src_base + offset]; + let d = f32(block_q8_0.d); + for (var j: u32 = 0; j < 8; j++) { + let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f32(q_byte) * d; + let dst_offset = dst_base + offset * 32 + j * 4 + k; + dst[dst_offset] = q_val; + } + } +} +#enddecl(Q8_0) + +#decl(Q2_K) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var dst_i = dst_base + offset * 256; + var is: u32 = 0; + // 2 halves of the block (128 elements each) + for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { + // 4 groups (each group has 2 blocks of 16 elements) + for (var shift: u32 = 0; shift < 8; shift += 2) { + // 2 blocks + for (var k: u32 = 0; k < 32; k += 16) { + let sc = get_byte(block.scales[is / 4], is % 4); + is++; + let dl = d * f32(sc & 0xF); + let ml = m * f32(sc >> 4); + for (var l: u32 = 0u; l < 16; l++) { + let q_idx = q_b_idx + k + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qs_val = (q_byte >> shift) & 3; + dst[dst_i] = (f32(qs_val) * dl - ml); + dst_i++; + } + } + } + } +} +#enddecl(Q2_K) + +#decl(Q3_K) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + + // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, + // and 2-bits from the last 4 bytes + let kmask1: u32 = 0x03030303; + let kmask2: u32 = 0x0f0f0f0f; + var scale_vals: array; + for (var i: u32 = 0; i < 4; i++) { + scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + } + var tmp: u32 = scale_vals[2]; + scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); + scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + // convert arrays of f16 -> u32 + var hmask_vals: array; + for (var i: u32 = 0; i < 8; i++) { + hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); + } + var qs_vals: array; + for (var i: u32 = 0; i < 16; i++) { + qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); + } + + var dst_i = dst_base + offset * 256; + var is: u32 = 0; + var m: u32 = 1; + // 2 halves of the block (128 elements each) + for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { + // 4 groups (each group has 2 blocks of 16 elements) + for (var shift: u32 = 0; shift < 8; shift += 2) { + // 2 blocks + for (var k: u32 = 0; k < 32; k += 16) { + let sc = get_byte(scale_vals[is / 4], is % 4); + is++; + let dl = d * (f32(sc) - 32.0); + for (var l: u32 = 0u; l < 16u; l++) { + let q_idx = q_b_idx + k + l; + let hm_idx = k + l; + let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); + let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); + let hm = select(4.0, 0.0, (hmask_byte & m) != 0); + let qs_val = (q_byte >> shift) & 3; + dst[dst_i] = (f32(qs_val) - hm) * dl; + dst_i++; + } + } + m <<= 1; + } + } +} +#enddecl(Q3_K) + +#decl(Q4_K) +// 8 blocks of 32 elements each +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var dst_i = dst_base + offset * 256; + var is: u32 = 0; + // 2 blocks each iteration + for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { + for (var shift: u32 = 0; shift < 8; shift += 4) { + let scale_min = get_scale_min(is, block.scales); + is++; + let dl = d * scale_min.x; + let ml = m * scale_min.y; + for (var l: u32 = 0; l < 32; l++) { + let q_idx = q_b_idx + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qs_val = (q_byte >> shift) & 0xF; + dst[dst_i] = (f32(qs_val) * dl - ml); + dst_i++; + } + } + } +} +#enddecl(Q4_K) + +#decl(Q5_K) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var dst_i = dst_base + offset * 256; + var is: u32 = 0; + var u: u32 = 1; + // 2 blocks each iteration + for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { + for (var shift: u32 = 0; shift < 8; shift += 4) { + let scale_min = get_scale_min(is, block.scales); + is++; + let dl = d * scale_min.x; + let ml = m * scale_min.y; + for (var l: u32 = 0; l < 32; l++) { + let q_idx = q_b_idx + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qh_byte = get_byte(block.qh[l / 4], l % 4); + let qs_val = (q_byte >> shift) & 0xF; + let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); + dst[dst_i] = (f32(qs_val) + qh_val) * dl - ml; + dst_i++; + } + u <<= 1; + } + } +} +#enddecl(Q5_K) + +#decl(Q6_K) +// 16 blocks of 16 elements each +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + + // convert arrays of f16 -> u32 + var ql_vals: array; + for (var i: u32 = 0; i < 32; i++) { + ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); + } + var qh_vals: array; + for (var i: u32 = 0; i < 16; i++) { + qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); + } + var scale_vals: array; + for (var i: u32 = 0; i < 4; i++) { + scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + } + + var dst_i = dst_base + offset * 256; + var qh_b_idx: u32 = 0; + var sc_b_idx: u32 = 0; + for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) { + for (var l: u32 = 0; l < 32; l++) { + let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4); + let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4); + let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4); + + let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0; + let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0; + let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0; + let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0; + + let is = l/16; + let is1 = sc_b_idx + is; + let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4); + let is2 = sc_b_idx + is + 2; + let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4); + let is3 = sc_b_idx + is + 4; + let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4); + let is4 = sc_b_idx + is + 6; + let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4); + + dst[dst_i + l] = (q1 * f32(sc1)) * d; + dst[dst_i + l + 32] = (q2 * f32(sc2)) * d; + dst[dst_i + l + 64] = (q3 * f32(sc3)) * d; + dst[dst_i + l + 96] = (q4 * f32(sc4)) * d; + } + dst_i += 128; + qh_b_idx += 32; + sc_b_idx += 8; + } +} + +#enddecl(Q6_K) + +#decl(IQ2_XXS) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + var dst_i = dst_base + offset * 256; + for (var ib: u32 = 0; ib < 32; ib += 4) { + let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); + let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); + let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; + for (var l: u32 = 0; l < 4; l++) { + let ig = get_byte(aux0, l) * 8; + let is = (aux1 >> (7 * l)) & 127; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + dst[dst_i] = db * f32(g) * m; + dst_i++; + } + } + } +} +#enddecl(IQ2_XXS) + +#decl(IQ2_XS) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + var dst_i = dst_base + offset * 256; + var scale_vals = array( + bitcast(vec2(block.scales[0], block.scales[1])), + bitcast(vec2(block.scales[2], block.scales[3])) + ); + for (var ib: u32 = 0; ib < 32; ib += 4) { + let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); + let db = array( + d * (0.5 + f32(s & 0xF)) * 0.25, + d * (0.5 + f32(s >> 4)) * 0.25 + ); + for (var l: u32 = 0; l < 4; l++) { + let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); + let ig = (qs_val & 511) * 8; + let is = qs_val >> 9; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + let dl = db[l/2]; + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + dst[dst_i] = dl * f32(g) * m; + dst_i++; + } + } + } +} +#enddecl(IQ2_XS) + +#decl(IQ2_S) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + var dst_i = dst_base + offset * 256; + var qs_vals : array; + for (var i: u32 = 0; i < 16; i++) { + qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + } + var qh_vals = array( + bitcast(vec2(block.qh[0], block.qh[1])), + bitcast(vec2(block.qh[2], block.qh[3])) + ); + var scale_vals = array( + bitcast(vec2(block.scales[0], block.scales[1])), + bitcast(vec2(block.scales[2], block.scales[3])) + ); + for (var ib: u32 = 0; ib < 8; ib ++) { + let s = get_byte(scale_vals[ib / 4], ib % 4); + let db = array( + d * (0.5 + f32(s & 0xF)) * 0.25, + d * (0.5 + f32(s >> 4)) * 0.25 + ); + let qs_w = qs_vals[ib]; + for (var l: u32 = 0; l < 4; l++) { + let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300; + let ig = (get_byte(qs_w, l) | qh_b) * 8; + let signs = get_byte(qs_vals[ib + 8], l); + let dl = db[l/2]; + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + dst[dst_i] = dl * f32(g) * m; + dst_i++; + } + } + } +} + +#enddecl(IQ2_S) + +#decl(IQ3_XSS) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + var dst_i = dst_base + offset * 256; + for (var ib: u32 = 0; ib < 16; ib += 2) { + let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); + let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; + for (var l: u32 = 0; l < 4; l++) { + let is = (sc_sign >> (7 * l)) & 127; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); + let ig1 = get_byte(ig_val, 0); + let ig2 = get_byte(ig_val, 1); + for (var j: u32 = 0; j < 4; j++) { + let g1 = get_byte(iq3xxs_grid[ig1], j); + let g2 = get_byte(iq3xxs_grid[ig2], j); + let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); + let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); + dst[dst_i] = db * f32(g1) * m1; + dst[dst_i + 4] = db * f32(g2) * m2; + dst_i++; + } + dst_i += 4; + } + } +} +#enddecl(IQ3_XSS) + +#decl(IQ3_S) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + var dst_i = dst_base + offset * 256; + var qh_vals = array( + bitcast(vec2(block.qh[0], block.qh[1])), + bitcast(vec2(block.qh[2], block.qh[3])) + ); + var sign_vals: array; + for (var i: u32 = 0; i < 8; i++) { + sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); + } + var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); + for (var ib: u32 = 0; ib < 4; ib++) { + let s = get_byte(scale_vals, ib); + let db = array( + d * (1.0 + 2.0 * f32(s & 0xF)), + d * (1.0 + 2.0 * f32(s >> 4)) + ); + for (var k: u32 = 0; k < 2; k++) { + let dl = db[k]; + let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k); + let sign_w = sign_vals[ib * 2 + k]; + for (var l: u32 = 0; l < 4; l++) { + let signs = get_byte(sign_w, l); + let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); + let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); + let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); + for (var j: u32 = 0; j < 4; j++) { + let g1 = get_byte(iq3s_grid[ig1], j); + let g2 = get_byte(iq3s_grid[ig2], j); + let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); + let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); + dst[dst_i] = dl * f32(g1) * m1; + dst[dst_i + 4] = dl * f32(g2) * m2; + dst_i++; + } + dst_i += 4; + } + } + } +} +#enddecl(IQ3_S) + +#decl(IQ1_S) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + var dst_i = dst_base + offset * 256; + for (var ib: u32 = 0; ib < 8; ib++) { + let qh = bitcast(vec2(block.qh[ib], 0.0)); + let dl = d * (2 * f32((qh >> 12) & 7) + 1); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); + let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); + for (var l: u32 = 0; l < 4; l++) { + let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; + for (var j: u32 = 0; j < 8; j++) { + let gw = iq1_grid[(ig + j) / 16]; + let g = (gw >> (((ig + j) % 16) * 2)) & 3; + let gs = bitcast(g << 30) >> 30; + dst[dst_i] = dl * (f32(gs) + delta); + dst_i++; + } + } + } +} + +#enddecl(IQ1_S) + +#decl(IQ1_M) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + + let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000); + let d = f32(bitcast>(scale).x); + var dst_i = dst_base + offset * 256; + for (var ib: u32 = 0; ib < 8; ib++) { + let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF; + let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7; + let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7; + var dl = array( + d * f32(2 * s1 + 1), + d * f32(2 * s2 + 1) + ); + + let qh = block.qh[ib / 2] >> (16 * (ib % 2)); + var idx = array( + get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700), + get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700), + get_byte(block.qs[ib], 2) | ((qh) & 0x700), + get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700) + ); + var delta = array( + select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0), + select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0), + select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0), + select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0) + ); + for (var l: u32 = 0; l < 4; l++) { + let ig = idx[l] * 8; + for (var j: u32 = 0; j < 8; j++) { + let gw = iq1_grid[(ig + j) / 16]; + let g = (gw >> (((ig + j) % 16) * 2)) & 3; + let gs = bitcast(g << 30) >> 30; + dst[dst_i] = dl[l/2] * (f32(gs) + delta[l]); + dst_i++; + } + } + } +} + +#enddecl(IQ1_M) + +#decl(IQ4_NL) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + var dst_i = dst_base + offset * 32; + var qs: array; + for (var i: u32 = 0; i < 4; i++) { + qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + } + for (var j: u32 = 0; j < 16; j++) { + let qsb = get_byte(qs[j / 4], j % 4); + dst[dst_i] = d * f32(kvalues_iq4nl[qsb & 0xF]); + dst[dst_i + 16] = d * f32(kvalues_iq4nl[qsb >> 4]); + dst_i++; + } +} +#enddecl(IQ4_NL) + +#decl(IQ4_XS) +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block = src[src_base + offset]; + let d = f32(block.d); + let scales_h = bitcast(vec2(block.scales_h, 0.0)); + var dst_i = dst_base + offset * 256; + for (var ib: u32 = 0; ib < 8; ib++) { + let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); + let dl = d * (f32(ls) - 32.0); + for (var j: u32 = 0; j < 16; j++) { + let iqs = ib * 16 + j; + let qsb = get_byte(block.qs[iqs / 4], iqs % 4); + dst[dst_i] = dl * f32(kvalues_iq4nl[qsb & 0xF]); + dst[dst_i + 16] = dl * f32(kvalues_iq4nl[qsb >> 4]); + dst_i++; + } + dst_i += 16; + } +} +#enddecl(IQ4_XS) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +DECLS + +@group(0) @binding(0) +var src: array<{{TYPE}}>; + +@group(0) @binding(1) +var idx: array; + +@group(0) @binding(2) +var dst: array<{{DST_TYPE}}>; + +struct Params { + offset_src: u32, // in elements + offset_idx: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_idx0: u32, + stride_idx1: u32, + stride_idx2: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of dst + ne0: u32, + n_rows: u32, + ne2: u32, + ne3: u32, + + // Shape of idx + idx1: u32, + idx2: u32, +}; + +@group(0) @binding(3) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.n_rows * params.ne2 * params.ne3) { + return; + } + var i = gid.x; + let i_dst3 = i / (params.ne2 * params.n_rows); + + i = i % (params.ne2 * params.n_rows); + let i_dst2 = i / params.n_rows; + let i_dst1 = i % params.n_rows; + + let i_idx2 = i_dst3 % params.idx2; + let i_idx1 = i_dst2 % params.idx1; + let i_idx0 = i_dst1; + + let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + + let idx_val = u32(idx[i_idx]); + + let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3; + let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3; + + for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) { + copy_elements(i_src_row, i_dst_row, i); + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl index cb7c8c3e09e..194d2d6f58c 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl @@ -19,20 +19,20 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let start = params.offset; let end = params.offset + params.size; - for (var j: u32 = 0u; j < bytes_per_thread; j = j + 1u) { + for (var j: u32 = 0u; j < bytes_per_thread; j += 4) { let byte_index = start + i + j; - if (byte_index + 4u <= end) { - output_buffer[(byte_index >> 2u)] = params.value; + if (byte_index + 4 <= end) { + output_buffer[byte_index >> 2] = params.value; } else { // Handle tail (unaligned) - for (var k: u32 = 0u; k < 4u; k = k + 1u) { + for (var k: u32 = 0; k < 4; k++) { let idx = byte_index + k; if (idx < end) { - let word_idx = idx >> 2u; - let byte_offset = (idx & 3u) * 8u; - let mask = ~(0xffu << byte_offset); + let word_idx = idx >> 2; + let bit_offset = (idx & 3) * 8u; + let mask = ~(0xffu << bit_offset); let existing = output_buffer[word_idx]; - output_buffer[word_idx] = (existing & mask) | ((params.value & 0xffu) << byte_offset); + output_buffer[word_idx] = (existing & mask) | (params.value & (0xffu << bit_offset)); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl new file mode 100644 index 00000000000..12506e1420e --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl @@ -0,0 +1,44 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "TYPE" : "f32", + } + }, + { + "REPLS": { + "TYPE" : "f16", + } + } +] + +#end(VARIANTS) + +#define(SHADER) + +enable f16; + +#include "binary_head.tmpl" + +@group(0) @binding(0) +var src0: array<{{TYPE}}>; + +@group(0) @binding(1) +var src1: array<{{TYPE}}>; + +@group(0) @binding(2) +var dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)]; + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl new file mode 100644 index 00000000000..e467e59edb4 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl @@ -0,0 +1,41 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "TYPE" : "f32", + } + }, + { + "REPLS": { + "TYPE" : "f16", + } + } +] + +#end(VARIANTS) + +#define(SHADER) + +enable f16; + +#include "binary_head.tmpl" + +@group(0) @binding(0) +var src0: array<{{TYPE}}>; + +@group(0) @binding(1) +var src1: array<{{TYPE}}>; + +@group(0) @binding(2) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)]; + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl new file mode 100644 index 00000000000..25e2185de84 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl @@ -0,0 +1,907 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "SRC0_TYPE" : "f32", + "SRC1_TYPE" : "f32", + "BLOCK_SIZE" : 1 + }, + "DECLS" : ["FLOAT"] + }, + { + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f16", + "BLOCK_SIZE" : 1 + }, + "DECLS" : ["FLOAT"] + }, + { + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "BLOCK_SIZE" : 1 + }, + "DECLS" : ["FLOAT"] + }, + { + "REPLS": { + "SRC0_TYPE": "q4_0", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"] + }, + { + "REPLS": { + "SRC0_TYPE": "q4_1", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"] + }, + { + "REPLS": { + "SRC0_TYPE": "q5_0", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"] + }, + { + "REPLS": { + "SRC0_TYPE": "q5_1", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"] + }, + { + "REPLS": { + "SRC0_TYPE": "q8_0", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"] + }, + { + "REPLS": { + "SRC0_TYPE": "q2_k", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"] + }, + { + "REPLS": { + "SRC0_TYPE": "q3_k", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"] + }, + { + "REPLS": { + "SRC0_TYPE": "q4_k", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"] + }, + { + "REPLS": { + "SRC0_TYPE": "q5_k", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"] + }, + { + "REPLS": { + "SRC0_TYPE": "q6_k", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq2_xxs", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq2_xs", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq2_s", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq3_xxs", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq3_s", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq1_s", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq1_m", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq4_nl", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 32, + }, + "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq4_xs", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256, + }, + "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(FLOAT) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]); +} +#enddecl(FLOAT) + +#decl(Q4_0) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q4_0 = src0[src0_idx_base + offset]; + let d = f32(block_q4_0.d); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 4; j++) { + let q_packed = bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; + let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_lo * f32(src1[src1_offset]); + sum += q_hi * f32(src1[src1_offset + 16]); + } + } + return sum; +} +#enddecl(Q4_0) + +#decl(Q4_1) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q4_1 = src0[src0_idx_base + offset]; + let d = f32(block_q4_1.d); + let m = f32(block_q4_1.m); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 4; j++) { + let q_packed = block_q4_1.qs[j]; + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32((q_byte >> 4) & 0xF) * d + m; + let q_lo = f32(q_byte & 0xF) * d + m; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_lo * f32(src1[src1_offset]); + sum += q_hi * f32(src1[src1_offset + 16]); + } + } + return sum; +} +#enddecl(Q4_1) + +#decl(Q5_0) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q5_0 = src0[src0_idx_base + offset]; + let d = f32(block_q5_0.d); + var sum: f32 = 0.0; + let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); + for (var j: u32 = 0; j < 4; j++) { + let q_packed = bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; + let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; + let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_lo * f32(src1[src1_offset]); + sum += q_hi * f32(src1[src1_offset + 16]); + } + } + return sum; +} +#enddecl(Q5_0) + +#decl(Q5_1) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q5_1 = src0[src0_idx_base + offset]; + let d = f32(block_q5_1.d); + let m = f32(block_q5_1.m); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 4; j++) { + let q_packed = block_q5_1.qs[j]; + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10; + let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m; + let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10; + let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_lo * f32(src1[src1_offset]); + sum += q_hi * f32(src1[src1_offset + 16]); + } + } + return sum; +} +#enddecl(Q5_1) + +#decl(Q8_0) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q8_0 = src0[src0_idx_base + offset]; + let d = f32(block_q8_0.d); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 8; j++) { + let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f32(q_byte) * d; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_val * f32(src1[src1_offset]); + } + } + return sum; +} +#enddecl(Q8_0) + +#decl(Q8_1) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q8_1 = src0[src0_idx_base + offset]; + let d = f32(block_q8_1.d); + let m = f32(block_q8_1.m); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 8; j++) { + let q_packed = block_q8_1.qs[j]; + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f32(q_byte) * d + m; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_val * f32(src1[src1_offset]); + } + } + return sum; +} +#enddecl(Q8_1) + +#decl(Q2_K) +// 16 blocks of 16 elements each +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var is: u32 = 0; + // 2 halves of the block (128 elements each) + for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { + // 4 groups (each group has 2 blocks of 16 elements) + for (var shift: u32 = 0; shift < 8; shift += 2) { + // 2 blocks + for (var k: u32 = 0; k < 32; k += 16) { + let sc = get_byte(block.scales[is / 4], is % 4); + is++; + let dl = d * f32(sc & 0xF); + let ml = m * f32(sc >> 4); + for (var l: u32 = 0u; l < 16; l++) { + let q_idx = q_b_idx + k + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qs_val = (q_byte >> shift) & 3; + sum += (f32(qs_val) * dl - ml) * src1[src1_i]; + src1_i++; + } + } + } + } + return sum; +} + +#enddecl(Q2_K) + +#decl(Q3_K) +// 16 blocks of 16 elements each +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + + // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, + // and 2-bits from the last 4 bytes + let kmask1: u32 = 0x03030303; + let kmask2: u32 = 0x0f0f0f0f; + var scale_vals: array; + for (var i: u32 = 0; i < 4; i++) { + scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + } + var tmp: u32 = scale_vals[2]; + scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); + scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + // convert arrays of f16 -> u32 + var hmask_vals: array; + for (var i: u32 = 0; i < 8; i++) { + hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); + } + var qs_vals: array; + for (var i: u32 = 0; i < 16; i++) { + qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); + } + + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var is: u32 = 0; + var m: u32 = 1; + // 2 halves of the block (128 elements each) + for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { + // 4 groups (each group has 2 blocks of 16 elements) + for (var shift: u32 = 0; shift < 8; shift += 2) { + // 2 blocks + for (var k: u32 = 0; k < 32; k += 16) { + let sc = get_byte(scale_vals[is / 4], is % 4); + is++; + let dl = d * (f32(sc) - 32.0); + for (var l: u32 = 0u; l < 16u; l++) { + let q_idx = q_b_idx + k + l; + let hm_idx = k + l; + let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); + let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); + let hm = select(4.0, 0.0, (hmask_byte & m) != 0); + let qs_val = (q_byte >> shift) & 3; + sum += ((f32(qs_val) - hm) * dl) * src1[src1_i]; + src1_i++; + } + } + m <<= 1; + } + } + return sum; +} + +#enddecl(Q3_K) + +#decl(Q4_K) +// 8 blocks of 32 elements each +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var is: u32 = 0; + // 2 blocks each iteration + for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { + for (var shift: u32 = 0; shift < 8; shift += 4) { + let scale_min = get_scale_min(is, block.scales); + is++; + let dl = d * scale_min.x; + let ml = m * scale_min.y; + for (var l: u32 = 0; l < 32; l++) { + let q_idx = q_b_idx + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qs_val = (q_byte >> shift) & 0xF; + sum += (f32(qs_val) * dl - ml) * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} + +#enddecl(Q4_K) + +#decl(Q5_K) +// 8 blocks of 32 elements each +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var is: u32 = 0; + var u: u32 = 1; + // 2 blocks each iteration + for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { + for (var shift: u32 = 0; shift < 8; shift += 4) { + let scale_min = get_scale_min(is, block.scales); + is++; + let dl = d * scale_min.x; + let ml = m * scale_min.y; + for (var l: u32 = 0; l < 32; l++) { + let q_idx = q_b_idx + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qh_byte = get_byte(block.qh[l / 4], l % 4); + let qs_val = (q_byte >> shift) & 0xF; + let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); + sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i]; + src1_i++; + } + u <<= 1; + } + } + return sum; +} + +#enddecl(Q5_K) + +#decl(Q6_K) +// 16 blocks of 16 elements each +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + + // convert arrays of f16 -> u32 + var ql_vals: array; + for (var i: u32 = 0; i < 32; i++) { + ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); + } + var qh_vals: array; + for (var i: u32 = 0; i < 16; i++) { + qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); + } + var scale_vals: array; + for (var i: u32 = 0; i < 4; i++) { + scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + } + + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var qh_b_idx: u32 = 0; + var sc_b_idx: u32 = 0; + for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) { + for (var l: u32 = 0; l < 32; l++) { + let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4); + let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4); + let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4); + + let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0; + let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0; + let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0; + let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0; + + let is = l/16; + let is1 = sc_b_idx + is; + let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4); + let is2 = sc_b_idx + is + 2; + let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4); + let is3 = sc_b_idx + is + 4; + let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4); + let is4 = sc_b_idx + is + 6; + let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4); + + sum += d * f32(sc1) * q1 * src1[src1_i + l]; + sum += d * f32(sc2) * q2 * src1[src1_i + l + 32]; + sum += d * f32(sc3) * q3 * src1[src1_i + l + 64]; + sum += d * f32(sc4) * q4 * src1[src1_i + l + 96]; + } + src1_i += 128; + qh_b_idx += 32; + sc_b_idx += 8; + } + return sum; +} + +#enddecl(Q6_K) + +#decl(IQ2_XXS) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 32; ib += 4) { + let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); + let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); + let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; + for (var l: u32 = 0; l < 4; l++) { + let ig = get_byte(aux0, l) * 8; + let is = (aux1 >> (7 * l)) & 127; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + sum += db * f32(g) * m * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} + +#enddecl(IQ2_XXS) + +#decl(IQ2_XS) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var scale_vals = array( + bitcast(vec2(block.scales[0], block.scales[1])), + bitcast(vec2(block.scales[2], block.scales[3])) + ); + var sum = 0.0; + for (var ib: u32 = 0; ib < 32; ib += 4) { + let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); + let db = array( + d * (0.5 + f32(s & 0xF)) * 0.25, + d * (0.5 + f32(s >> 4)) * 0.25 + ); + for (var l: u32 = 0; l < 4; l++) { + let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); + let ig = (qs_val & 511) * 8; + let is = qs_val >> 9; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + let dl = db[l/2]; + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + sum += dl * f32(g) * m * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} + +#enddecl(IQ2_XS) + +#decl(IQ2_S) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var qs_vals : array; + for (var i: u32 = 0; i < 16; i++) { + qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + } + var qh_vals = array( + bitcast(vec2(block.qh[0], block.qh[1])), + bitcast(vec2(block.qh[2], block.qh[3])) + ); + var scale_vals = array( + bitcast(vec2(block.scales[0], block.scales[1])), + bitcast(vec2(block.scales[2], block.scales[3])) + ); + var sum = 0.0; + for (var ib: u32 = 0; ib < 8; ib ++) { + let s = get_byte(scale_vals[ib / 4], ib % 4); + let db = array( + d * (0.5 + f32(s & 0xF)) * 0.25, + d * (0.5 + f32(s >> 4)) * 0.25 + ); + let qs_w = qs_vals[ib]; + for (var l: u32 = 0; l < 4; l++) { + let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300; + let ig = (get_byte(qs_w, l) | qh_b) * 8; + let signs = get_byte(qs_vals[ib + 8], l); + let dl = db[l/2]; + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + sum += dl * f32(g) * m * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} + + +#enddecl(IQ2_S) + +#decl(IQ3_XSS) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 16; ib += 2) { + let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); + let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; + for (var l: u32 = 0; l < 4; l++) { + let is = (sc_sign >> (7 * l)) & 127; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); + let ig1 = get_byte(ig_val, 0); + let ig2 = get_byte(ig_val, 1); + for (var j: u32 = 0; j < 4; j++) { + let g1 = get_byte(iq3xxs_grid[ig1], j); + let g2 = get_byte(iq3xxs_grid[ig2], j); + let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); + let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); + sum += db * f32(g1) * m1 * src1[src1_i]; + sum += db * f32(g2) * m2 * src1[src1_i + 4]; + src1_i++; + } + src1_i += 4; + } + } + return sum; +} + +#enddecl(IQ3_XSS) + +#decl(IQ3_S) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var qh_vals = array( + bitcast(vec2(block.qh[0], block.qh[1])), + bitcast(vec2(block.qh[2], block.qh[3])) + ); + var sign_vals: array; + for (var i: u32 = 0; i < 8; i++) { + sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); + } + var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); + var sum = 0.0; + for (var ib: u32 = 0; ib < 4; ib++) { + let s = get_byte(scale_vals, ib); + let db = array( + d * (1.0 + 2.0 * f32(s & 0xF)), + d * (1.0 + 2.0 * f32(s >> 4)) + ); + for (var k: u32 = 0; k < 2; k++) { + let dl = db[k]; + let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k); + let sign_w = sign_vals[ib * 2 + k]; + for (var l: u32 = 0; l < 4; l++) { + let signs = get_byte(sign_w, l); + let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); + let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); + let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); + for (var j: u32 = 0; j < 4; j++) { + let g1 = get_byte(iq3s_grid[ig1], j); + let g2 = get_byte(iq3s_grid[ig2], j); + let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); + let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); + sum += dl * f32(g1) * m1 * src1[src1_i]; + sum += dl * f32(g2) * m2 * src1[src1_i + 4]; + src1_i++; + } + src1_i += 4; + } + } + } + return sum; +} +#enddecl(IQ3_S) + +#decl(IQ1_S) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 8; ib++) { + let qh = bitcast(vec2(block.qh[ib], 0.0)); + let dl = d * (2 * f32((qh >> 12) & 7) + 1); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); + let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); + for (var l: u32 = 0; l < 4; l++) { + let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; + for (var j: u32 = 0; j < 8; j++) { + let gw = iq1_grid[(ig + j) / 16]; + let g = (gw >> (((ig + j) % 16) * 2)) & 3; + let gs = bitcast(g << 30) >> 30; + sum += dl * (f32(gs) + delta) * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} + +#enddecl(IQ1_S) + +#decl(IQ1_M) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + + let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000); + let d = f32(bitcast>(scale).x); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 8; ib++) { + let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF; + let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7; + let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7; + var dl = array( + d * f32(2 * s1 + 1), + d * f32(2 * s2 + 1) + ); + + let qh = block.qh[ib / 2] >> (16 * (ib % 2)); + var idx = array( + get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700), + get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700), + get_byte(block.qs[ib], 2) | ((qh) & 0x700), + get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700) + ); + var delta = array( + select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0), + select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0), + select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0), + select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0) + ); + for (var l: u32 = 0; l < 4; l++) { + let ig = idx[l] * 8; + for (var j: u32 = 0; j < 8; j++) { + let gw = iq1_grid[(ig + j) / 16]; + let g = (gw >> (((ig + j) % 16) * 2)) & 3; + let gs = bitcast(g << 30) >> 30; + sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} + +#enddecl(IQ1_M) + +#decl(IQ4_NL) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 32; + var sum = 0.0; + var qs: array; + for (var i: u32 = 0; i < 4; i++) { + qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + } + for (var j: u32 = 0; j < 16; j++) { + let qsb = get_byte(qs[j / 4], j % 4); + sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; + sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; + src1_i++; + } + return sum; +} + +#enddecl(IQ4_NL) + +#decl(IQ4_XS) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + let scales_h = bitcast(vec2(block.scales_h, 0.0)); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 8; ib++) { + let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); + let dl = d * (f32(ls) - 32.0); + for (var j: u32 = 0; j < 16; j++) { + let iqs = ib * 16 + j; + let qsb = get_byte(block.qs[iqs / 4], iqs % 4); + sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; + sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; + src1_i++; + } + src1_i += 16; + } + return sum; +} + +#enddecl(IQ4_XS) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +DECLS + +struct MulMatParams { + offset_src0: u32, // in elements/blocks + offset_src1: u32, // in elements/blocks + offset_dst: u32, // in elements/blocks + m: u32, + n: u32, + k: u32, + // all strides are in elements/blocks + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // N rows, K columns +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns + +@group(0) @binding(3) var params: MulMatParams; + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + if (global_id.x >= total) { + return; + } + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = global_id.x / dst3_stride; + let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension + let src13_idx = dst3_idx; // src1 is not broadcast + let dst3_rem = global_id.x % dst3_stride; + + let dst2_idx = dst3_rem / dst2_stride; + let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension + let src12_idx = dst2_idx; // src1 is not broadcast + + let dst2_rem = dst3_rem % dst2_stride; + + let row = dst2_rem / params.n; // output row + let col = dst2_rem % params.n; // output column + + let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; + + var sum = 0.0; + for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) { + sum += multiply_add(src0_idx_base, src1_idx_base, i); + } + dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum; +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl deleted file mode 100644 index 054aab566f9..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ /dev/null @@ -1,56 +0,0 @@ -struct MulMatParams { - m: u32, - n: u32, - k: u32, - // all strides are in elements - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var src0: array; // N rows, K columns -@group(0) @binding(1) var src1: array; // M rows, K columns (transposed) -@group(0) @binding(2) var dst: array; // M rows, N columns - -@group(0) @binding(3) var params: MulMatParams; - -@compute @workgroup_size(64) -fn main(@builtin(global_invocation_id) global_id: vec3) { - let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - if (global_id.x >= total) { - return; - } - - let dst2_stride = params.m * params.n; - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - - let dst3_idx = global_id.x / dst3_stride; - let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension - let src13_idx = dst3_idx; // src1 is not broadcast - let dst3_rem = global_id.x % dst3_stride; - - let dst2_idx = dst3_rem / dst2_stride; - let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension - let src12_idx = dst2_idx; // src1 is not broadcast - - let dst2_rem = dst3_rem % dst2_stride; - - let row = dst2_rem / params.n; // output row - let col = dst2_rem % params.n; // output column - - var sum = 0.0; - for (var i: u32 = 0u; i < params.k; i = i + 1u) { - let src0_idx = src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01 + i; - let src1_idx = src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11 + i; - sum = sum + src0[src0_idx] * src1[src1_idx]; - } - dst[dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum; -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl new file mode 100644 index 00000000000..f919a513363 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl @@ -0,0 +1,57 @@ +@group(0) @binding(0) +var src: array; + +@group(0) @binding(1) +var dst: array; + +struct Params { + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of src/dst + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, + + eps: u32 +}; + +@group(0) @binding(2) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne1 * params.ne2 * params.ne3) { + return; + } + + // one thread per row + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; + let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + var sum = 0.0f; + for (var j: u32 = 0; j < params.ne0; j++) { + sum += src[i_src_row + j] * src[i_src_row + j]; + } + let eps = bitcast(params.eps); + let scale = 1.0/sqrt(sum/f32(params.ne0) + eps); + for (var j: u32 = 0; j < params.ne0; j++) { + dst[i_dst_row + j] = scale * src[i_src_row + j]; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl new file mode 100644 index 00000000000..ae84f556d60 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl @@ -0,0 +1,48 @@ +@group(0) @binding(0) +var a: array; + +struct Params { + offset: u32, // in elements + + // Strides (in elements) + stride1: u32, + stride2: u32, + stride3: u32, + + // Shape + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, + + eps: u32 +}; + +@group(0) @binding(1) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne1 * params.ne2 * params.ne3) { + return; + } + + // one thread per row + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_row = params.offset + i3 * params.stride3 + i2 * params.stride2 + i1 * params.stride1; + + var sum = 0.0f; + for (var j: u32 = 0; j < params.ne0; j++) { + sum += a[i_row + j] * a[i_row + j]; + } + let eps = bitcast(params.eps); + let scale = 1.0/sqrt(sum/f32(params.ne0) + eps); + for (var j: u32 = 0; j < params.ne0; j++) { + a[i_row + j] = scale * a[i_row + j]; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl index 4bd6f94a235..3567713dc21 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl @@ -52,7 +52,6 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } var i = gid.x; let i_src3 = i / (params.ne2 * params.n_rows); - let i_dst3 = i / (params.ne2 * 3); i = i % (params.ne2 * params.n_rows); let i_src2 = i / params.n_rows; diff --git a/ggml/src/ggml-zdnn/ggml-zdnn-impl.h b/ggml/src/ggml-zdnn/ggml-zdnn-impl.h index 9dcb040fa89..a4153818158 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +++ b/ggml/src/ggml-zdnn/ggml-zdnn-impl.h @@ -76,6 +76,7 @@ struct ggml_backend_zdnn_context { struct ggml_backend_zdnn_buffer { void * data; + ggml_backend_zdnn_buffer * extra; // for bias, etc. size_t size; zdnn_tensor_desc pre_tfm_desc; diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index 7507a52aeae..57a8f266201 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -115,9 +115,7 @@ static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_ten ggml_backend_zdnn_buffer * weights_extra = (ggml_backend_zdnn_buffer *)weights->extra; ggml_backend_zdnn_buffer * inputs_extra = (ggml_backend_zdnn_buffer *)inputs->extra; ggml_backend_zdnn_buffer * output_extra = (ggml_backend_zdnn_buffer *)output->extra; - - zdnn_tensor_desc ptd_bias, td_bias; - zdnn_ztensor zt_bias; + ggml_backend_zdnn_buffer * bias_extra = (ggml_backend_zdnn_buffer *)output_extra->extra; const int64_t weights_rows = ne01; const int64_t weights_cols = ne00; @@ -129,14 +127,6 @@ static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_ten const int64_t output_rows = ne1; const int64_t output_cols = ne0; - const int64_t bias_dim [GGML_MAX_DIMS] = { 1, 1, 1, output_cols }; - ggml_zdnn_create_tensor(ptd_bias, td_bias, zt_bias, output, bias_dim, ZDNN_1D); - - void * bias_data = (void *)calloc(ne0, ggml_element_size(output)); - if (weights_extra->ztensor.is_transformed == false) ggml_zdnn_load_tensor(weights_extra->ztensor, weights->data); - if (inputs_extra->ztensor.is_transformed == false) ggml_zdnn_load_tensor(inputs_extra->ztensor, inputs->data); - ggml_zdnn_load_tensor(zt_bias, bias_data); - // GGML_LOG_INFO("%s: tensor '%s' tensor dimensions: [%ld, %ld, %ld, %ld] pre_tfm_desc dimensions: [%ld, %ld, %ld, %ld]\n", // __func__, weights_extra->name, // weights->ne[3], weights->ne[2], weights->ne[1], weights->ne[0], @@ -158,29 +148,21 @@ static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_ten GGML_ASSERT(inputs_extra->pre_tfm_desc.dim1 == inputs->ne[0] && "inputs_extra->pre_tfm_desc.dim1 must match inputs->ne[0]"); GGML_ASSERT(inputs_extra->pre_tfm_desc.dim2 == inputs->ne[1] && "inputs_extra->pre_tfm_desc.dim2 must match inputs->ne[1]"); - ZDNN_CHECK(zdnn_matmul_transpose_op(&inputs_extra->ztensor, &weights_extra->ztensor, &zt_bias, + ZDNN_CHECK(zdnn_matmul_transpose_op(&inputs_extra->ztensor, &weights_extra->ztensor, &bias_extra->ztensor, false, true, MATMUL_OP_ADDITION, &output_extra->ztensor)); // TODO: Remove in the future as we are currently DLF16 -> FP32 then in the next op, FP32 -> DLF16 again. Inefficient. ZDNN_CHECK(zdnn_transform_origtensor(&output_extra->ztensor, output->data)); - ZDNN_CHECK(zdnn_free_ztensor_buffer(&zt_bias)); - free(bias_data); + GGML_UNUSED(ctx); + GGML_UNUSED(weights_rows); + GGML_UNUSED(weights_cols); + GGML_UNUSED(inputs_rows); + GGML_UNUSED(inputs_cols); + GGML_UNUSED(output_rows); + GGML_UNUSED(output_cols); } static void ggml_zdnn_mul_mat_dispatch(ggml_backend_zdnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - bool use_mul_mat_vec = - (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F16) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; - - bool use_mul_mat_vec_q = - ggml_is_quantized(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; - - bool use_mul_mat_q = - ggml_is_quantized(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; - // debug helpers // GGML_LOG_INFO("%s: use_mul_mat_vec = %d\n", __func__, use_mul_mat_vec); // GGML_LOG_INFO("%s: use_mul_mat_vec_q = %d\n", __func__, use_mul_mat_vec_q); @@ -192,25 +174,7 @@ static void ggml_zdnn_mul_mat_dispatch(ggml_backend_zdnn_context * ctx, const gg // GGML_LOG_INFO("%s: src0 is contiguous %d, transposed %d, type = %s, name = %s\n", __func__, ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); // GGML_LOG_INFO("%s: src1 is contiguous %d, transposed %d, type = %s, name = %s\n", __func__, ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 - && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) - && src1->ne[2] * src1->ne[3] > 1) { - // general KQ + KQV multi-batch - GGML_LOG_INFO("%s: using zdnn_mul_mat_batched for KQ + KQV multi-batch\n", __func__); - // ggml_zdnn_mul_mat_batched(ctx, src0, src1, dst); - } else if (use_mul_mat_vec) { - GGML_LOG_INFO("%s: using zdnn_op_mul_mat_vec for vector multiplication\n", __func__); - // ggml_zdnn_op_mul_mat(ctx, src0, src1, dst, ggml_zdnn_op_mul_mat_vec, nullptr); - } else if (use_mul_mat_vec_q) { - GGML_LOG_INFO("%s: using zdnn_op_mul_mat_vec_q for quantized vector multiplication\n", __func__); - // ggml_zdnn_op_mul_mat(ctx, src0, src1, dst, ggml_zdnn_op_mul_mat_vec_q, ggml_zdnn_quantize_row_q8_1); - } else if (use_mul_mat_q) { - GGML_LOG_INFO("%s: using zdnn_op_mul_mat_q for quantized matrix multiplication\n", __func__); - // ggml_zdnn_op_mul_mat(ctx, src0, src1, dst, ggml_zdnn_op_mul_mat_q, ggml_zdnn_quantize_mmq_q8_1); - } else { - // GGML_LOG_INFO("%s: using zdnn_op_mul_mat for general matrix multiplication\n", __func__); - ggml_zdnn_mul_mat_op(ctx, src0, src1, dst); - } + ggml_zdnn_mul_mat_op(ctx, src0, src1, dst); } static bool ggml_zdnn_compute_forward(ggml_backend_zdnn_context * ctx, ggml_tensor * dst) { @@ -253,6 +217,8 @@ static enum ggml_status ggml_zdnn_graph_compute(ggml_backend_t backend, ggml_cgr } return GGML_STATUS_SUCCESS; + + GGML_UNUSED(ctx_dev); } static bool ggml_zdnn_supports_op(const ggml_backend_zdnn_device_context * ctx_dev, const ggml_tensor * op) { @@ -266,22 +232,30 @@ static bool ggml_zdnn_supports_op(const ggml_backend_zdnn_device_context * ctx_d case GGML_OP_MUL_MAT: { - const ggml_tensor * src0 = op->src[0]; - const ggml_tensor * src1 = op->src[1]; + const ggml_tensor * weights = op->src[0]; + const ggml_tensor * inputs = op->src[1]; - const int64_t ne10 = src1->ne[0]; - const int64_t ne0 = op->ne[0]; - const int64_t ne1 = op->ne[1]; + const int64_t ne10 = inputs->ne[0]; + const int64_t ne0 = op->ne[0]; + const int64_t ne1 = op->ne[1]; const int64_t max_batch = ctx_dev->max_size; - return ggml_is_matrix(src0) && - ggml_is_matrix(src1) && - ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && - src0->view_src == nullptr && src1->view_src == nullptr && - src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && - (ne0 <= max_batch && ne1 <= max_batch && ne10 <= max_batch); + if (!ggml_is_matrix(weights) || !ggml_is_matrix(inputs) || + !ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) || + weights->view_src != nullptr || inputs->view_src != nullptr || + ne0 > max_batch || ne1 > max_batch || ne10 > max_batch) { + return false; + } + + switch (weights->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + return true; + default: + return false; + } } break; default: @@ -374,10 +348,12 @@ static void ggml_zdnn_free(ggml_backend_zdnn_context * ctx) { static void ggml_backend_zdnn_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_zdnn_buffer_context * ctx = (ggml_backend_zdnn_buffer_context *)buffer->context; - for (int i = 0; i < ctx->n_buffers; i++) { - if (ctx->buffers[i]->ztensor.buffer != NULL && ctx->buffers[i]->ztensor.is_transformed) { - ZDNN_CHECK(zdnn_free_ztensor_buffer(&ctx->buffers[i]->ztensor)); - } + for (const auto & buf_ptr : ctx->buffers) { + ggml_backend_zdnn_buffer * buf = buf_ptr.get(); + + // Free any extra buffer allocated for the tensor. E.g., bias for GGML_OP_MUL_MAT + if (buf->extra != nullptr) free(buf->extra->data); + if (buf->ztensor.buffer_size > 0) ZDNN_CHECK(zdnn_free_ztensor_buffer(&buf->ztensor)); } delete ctx; @@ -402,11 +378,37 @@ static enum ggml_status ggml_backend_zdnn_buffer_init_tensor(ggml_backend_buffer std::unique_ptr zdnn_buffer = std::make_unique(); zdnn_buffer->data = tensor->data; zdnn_buffer->size = tsize; - strncpy(zdnn_buffer->name, tensor->name, GGML_MAX_NAME - 1); + zdnn_buffer->extra = nullptr; + snprintf(zdnn_buffer->name, GGML_MAX_NAME, "%s", tensor->name); ggml_zdnn_init_tensor(zdnn_buffer.get(), tensor); tensor->extra = zdnn_buffer.get(); + switch (tensor->op) { + case GGML_OP_MUL_MAT: + { + std::unique_ptr zdnn_bias_buffer = std::make_unique(); + zdnn_bias_buffer->data = (void *)calloc(tensor->ne[0], ggml_element_size(tensor)); + zdnn_bias_buffer->size = ggml_element_size(tensor) * tensor->ne[0]; + snprintf(zdnn_bias_buffer->name, GGML_MAX_NAME, "%.*s (bias)", + GGML_MAX_NAME - (int)sizeof(" (bias)"), tensor->name); + + const int64_t bias_dim[GGML_MAX_DIMS] = { 1, 1, 1, tensor->ne[0] }; + ggml_zdnn_create_tensor(zdnn_bias_buffer->pre_tfm_desc, + zdnn_bias_buffer->tfm_desc, + zdnn_bias_buffer->ztensor, + tensor, bias_dim, ZDNN_1D); + + ggml_zdnn_load_tensor(zdnn_bias_buffer->ztensor, zdnn_bias_buffer->data); + zdnn_buffer->extra = zdnn_bias_buffer.get(); + + ctx->buffers.push_back(std::move(zdnn_bias_buffer)); + ctx->n_buffers++; + } break; + default: + break; + } + ctx->buffers.push_back(std::move(zdnn_buffer)); ctx->n_buffers++; @@ -414,6 +416,8 @@ static enum ggml_status ggml_backend_zdnn_buffer_init_tensor(ggml_backend_buffer // __func__, tensor->name, buffer_idx, tsize); return GGML_STATUS_SUCCESS; + + GGML_UNUSED(buffer_idx); } static void ggml_backend_zdnn_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { @@ -425,6 +429,13 @@ static void ggml_backend_zdnn_buffer_memset_tensor(ggml_backend_buffer_t buffer, static void ggml_backend_zdnn_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { memcpy((char *)tensor->data + offset, data, size); + ggml_backend_zdnn_buffer * extra = (ggml_backend_zdnn_buffer *)tensor->extra; + + // Fixes the LLAMA_SET_ROWS bug + // see: https://github.com/ggml-org/llama.cpp/issues/15414 + if (tensor->buffer->usage == GGML_BACKEND_BUFFER_USAGE_COMPUTE && extra->ztensor.is_transformed) zdnn_reset_ztensor(&extra->ztensor); + if (extra->ztensor.is_transformed == false) ggml_zdnn_load_tensor(extra->ztensor, tensor->data); + GGML_UNUSED(buffer); } @@ -528,29 +539,6 @@ ggml_backend_buffer_type_t ggml_backend_zdnn_buffer_type(void) { return &ggml_backend_buffer_type_zdnn; } -static const char * ggml_backend_zdnn_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) { - return GGML_ZDNN_NAME "_Mapped"; - - GGML_UNUSED(buft); -} - -static ggml_backend_buffer_type_t ggml_backend_zdnn_buffer_from_ptr_type(void) { - static ggml_backend_buffer_type ggml_backend_buffer_from_ptr_type_zdnn = { - /* .iface = */ { - /* .get_name = */ ggml_backend_zdnn_buffer_from_ptr_type_get_name, - /* .alloc_buffer = */ ggml_backend_zdnn_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_zdnn_buffer_type_get_alignment, - /* .get_max_size = */ NULL, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ ggml_backend_zdnn_buffer_type_is_host, - }, - /* .device = */ &g_ggml_backend_zdnn_device, - /* .context = */ NULL, - }; - - return &ggml_backend_buffer_from_ptr_type_zdnn; -} - // // backend // @@ -586,6 +574,7 @@ static ggml_backend_i ggml_backend_zdnn_i = { /* .graph_compute = */ ggml_backend_zdnn_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, }; static ggml_guid_t ggml_backend_zdnn_guid(void) { @@ -593,27 +582,6 @@ static ggml_guid_t ggml_backend_zdnn_guid(void) { return reinterpret_cast((void *)guid_str); } -// TODO: remove in the future -ggml_backend_t ggml_backend_zdnn_init(void) { - ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_zdnn_reg(), 0); - - ggml_backend_zdnn_context * ctx = ggml_zdnn_init(dev); - if (ctx == NULL) { - GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); - return NULL; - } - - ggml_backend_t backend = (ggml_backend_t)malloc(sizeof(ggml_backend)); - *backend = (ggml_backend) { - /* .guid = */ ggml_backend_zdnn_guid(), - /* .iface = */ ggml_backend_zdnn_i, - /* .device = */ dev, - /* .context = */ ctx, - }; - - return backend; -} - bool ggml_backend_is_zdnn(ggml_backend_t backend) { return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_zdnn_guid()); @@ -633,11 +601,15 @@ static const char * ggml_backend_zdnn_device_get_name(ggml_backend_dev_t dev) { static const char * ggml_backend_zdnn_device_get_description(ggml_backend_dev_t dev) { return "IBM Z Neural Network Processing Assist (NNPA)"; + + GGML_UNUSED(dev); } static void ggml_backend_zdnn_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { *free = 0; *total = 0; + + GGML_UNUSED(dev); } static enum ggml_backend_dev_type ggml_backend_zdnn_device_get_type(ggml_backend_dev_t dev) { @@ -654,8 +626,8 @@ static void ggml_backend_zdnn_device_get_props(ggml_backend_dev_t dev, ggml_back props->caps = (ggml_backend_dev_caps) { /* .async = */ false, /* .host_buffer = */ false, - /* .buffer_from_host_ptr = */ true, - /* .events = */ false, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false }; } @@ -671,7 +643,7 @@ static ggml_backend_t ggml_backend_zdnn_device_init(ggml_backend_dev_t dev, cons /* .guid = */ ggml_backend_zdnn_guid(), /* .iface = */ ggml_backend_zdnn_i, /* .device = */ dev, - /* .context = */ ctx, + /* .context = */ ctx }; return backend; @@ -685,46 +657,6 @@ static ggml_backend_buffer_type_t ggml_backend_zdnn_device_get_buffer_type(ggml_ GGML_UNUSED(dev); } -static ggml_backend_buffer_t ggml_backend_zdnn_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { - ggml_backend_zdnn_buffer_context * ctx = new ggml_backend_zdnn_buffer_context(); - - ctx->all_data = ptr; - ctx->all_size = size; - ctx->owned = false; - ctx->n_buffers = 0; - - const size_t size_page = sysconf(_SC_PAGESIZE); - - // page-align the data ptr - { - const uintptr_t offs = (uintptr_t) ptr % size_page; - ptr = (void *)((char *)ptr - offs); - size += offs; - } - - size_t size_aligned = size; - if ((size_aligned % size_page) != 0) { - size_aligned += size_page - (size_aligned % size_page); - } - - ggml_backend_zdnn_device_context * ctx_dev = (ggml_backend_zdnn_device_context *)dev->context; - - GGML_ASSERT(ctx_dev->zdnn_device >= 0); - int device = ctx_dev->zdnn_device; GGML_UNUSED(device); - - std::unique_ptr zdnn_buffer = std::make_unique(); - zdnn_buffer->data = ptr; - zdnn_buffer->size = size; - ctx->buffers.push_back(std::move(zdnn_buffer)); - - GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB\n", - __func__, size_aligned / 1024.0 / 1024.0); - - ++ctx->n_buffers; - - return ggml_backend_buffer_init(ggml_backend_zdnn_buffer_from_ptr_type(), ggml_backend_zdnn_buffer_i, ctx, size); -} - static bool ggml_backend_zdnn_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { ggml_backend_zdnn_device_context * ctx_dev = (ggml_backend_zdnn_device_context *) dev->context; @@ -733,8 +665,7 @@ static bool ggml_backend_zdnn_device_supports_op(ggml_backend_dev_t dev, const g static bool ggml_backend_zdnn_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { return - buft->iface.get_name == ggml_backend_zdnn_buffer_type_get_name || - buft->iface.get_name == ggml_backend_zdnn_buffer_from_ptr_type_get_name; + buft->iface.get_name == ggml_backend_zdnn_buffer_type_get_name; GGML_UNUSED(dev); } @@ -748,7 +679,7 @@ static ggml_backend_device_i ggml_backend_zdnn_device_i = { /* .init_backend = */ ggml_backend_zdnn_device_init, /* .get_buffer_type = */ ggml_backend_zdnn_device_get_buffer_type, /* .get_host_buffer_type = */ NULL, - /* .buffer_from_host_ptr = */ ggml_backend_zdnn_device_buffer_from_ptr, + /* .buffer_from_host_ptr = */ NULL, /* .supports_op = */ ggml_backend_zdnn_device_supports_op, /* .supports_buft = */ ggml_backend_zdnn_device_supports_buft, /* .offload_op = */ NULL, @@ -812,7 +743,7 @@ static ggml_backend_reg_i ggml_backend_zdnn_reg_i = { /* .get_name = */ ggml_backend_zdnn_reg_get_name, /* .get_device_count = */ ggml_backend_zdnn_reg_device_count, /* .get_device = */ ggml_backend_zdnn_reg_device_get, - /* .get_proc_address = */ ggml_backend_zdnn_get_proc_address, + /* .get_proc_address = */ ggml_backend_zdnn_get_proc_address }; static void ggml_zdnn_cleanup(void) { @@ -830,13 +761,13 @@ ggml_backend_reg_t ggml_backend_zdnn_reg(void) { g_ggml_backend_zdnn_reg = (ggml_backend_reg) { /* .api_version = */ GGML_ZDNN_VERSION, /* .iface = */ ggml_backend_zdnn_reg_i, - /* .context = */ NULL, + /* .context = */ NULL }; g_ggml_backend_zdnn_device = (ggml_backend_device) { /* .iface = */ ggml_backend_zdnn_device_i, /* .reg = */ &g_ggml_backend_zdnn_reg, - /* .context = */ &g_ggml_ctx_dev_main, + /* .context = */ &g_ggml_ctx_dev_main }; return &g_ggml_backend_zdnn_reg; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a4417f1a17e..3584827dca7 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -974,7 +974,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CONV_TRANSPOSE_1D", "IM2COL", "IM2COL_BACK", + "IM2COL_3D", "CONV_2D", + "CONV_3D", "CONV_2D_DW", "CONV_TRANSPOSE_2D", "POOL_1D", @@ -1017,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88"); +static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1076,7 +1078,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "conv_transpose_1d(x)", "im2col(x)", "im2col_back(x)", + "im2col_3d(x)", "conv_2d(x)", + "conv_3d(x)", "conv_2d_dw(x)", "conv_transpose_2d(x)", "pool_1d(x)", @@ -1119,7 +1123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88"); +static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -3619,6 +3623,7 @@ struct ggml_tensor * ggml_get_rows( struct ggml_tensor * a, struct ggml_tensor * b) { GGML_ASSERT(a->ne[2] == b->ne[1]); + GGML_ASSERT(a->ne[3] == b->ne[2]); GGML_ASSERT(b->ne[3] == 1); GGML_ASSERT(b->type == GGML_TYPE_I32); @@ -4359,6 +4364,91 @@ struct ggml_tensor * ggml_conv_2d( return result; } +// a: [OC*IC, KD, KH, KW] +// b: [N*IC, ID, IH, IW] +// result: [N*OD, OH, OW, IC * KD * KH * KW] +struct ggml_tensor * ggml_im2col_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t IC, + int s0, // stride width + int s1, // stride height + int s2, // stride depth + int p0, // padding width + int p1, // padding height + int p2, // padding depth + int d0, // dilation width + int d1, // dilation height + int d2, // dilation depth + enum ggml_type dst_type) { + const int64_t N = b->ne[3] / IC; + const int64_t ID = b->ne[2]; + const int64_t IH = b->ne[1]; + const int64_t IW = b->ne[0]; + + const int64_t OC = a->ne[3] / IC; + UNUSED(OC); + const int64_t KD = a->ne[2]; + const int64_t KH = a->ne[1]; + const int64_t KW = a->ne[0]; + const int64_t OD = ggml_calc_conv_output_size(ID, KD, s2, p2, d2); + const int64_t OH = ggml_calc_conv_output_size(IH, KH, s1, p1, d1); + const int64_t OW = ggml_calc_conv_output_size(IW, KW, s0, p0, d0); + + GGML_ASSERT((OD > 0) && "b too small compared to a"); + GGML_ASSERT((OH > 0) && "b too small compared to a"); + GGML_ASSERT((OW > 0) && "b too small compared to a"); + + + const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N}; + + struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne); + int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC}; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_IM2COL_3D; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// a: [OC*IC, KD, KH, KW] +// b: [N*IC, ID, IH, IW] +// result: [N*OC, OD, OH, OW] +struct ggml_tensor * ggml_conv_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t IC, + int s0, // stride width + int s1, // stride height + int s2, // stride depth + int p0, // padding width + int p1, // padding height + int p2, // padding depth + int d0, // dilation width + int d1, // dilation height + int d2 // dilation depth + ) { + struct ggml_tensor * im2col = ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW] + + int64_t OC = a->ne[3] / IC; + int64_t N = b->ne[3] / IC; + struct ggml_tensor * result = + ggml_mul_mat(ctx, + ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW] + ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC)); // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW] + + int64_t OD = im2col->ne[3] / N; + result = ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW] + result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW] + result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW] + + return result; +} + // ggml_conv_2d_sk_p0 struct ggml_tensor * ggml_conv_2d_sk_p0( @@ -4480,6 +4570,56 @@ struct ggml_tensor * ggml_conv_2d_direct( return result; } +// ggml_conv_3d_direct + +struct ggml_tensor * ggml_conv_3d_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int s2, + int p0, + int p1, + int p2, + int d0, + int d1, + int d2, + int c, + int n, + int oc) { + + GGML_ASSERT(a->ne[3] == (int64_t) c * oc); + GGML_ASSERT(b->ne[3] == (int64_t) c * n); + + int64_t ne[4]; + ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1); + ne[2] = ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2); + ne[3] = (int64_t) oc * n; + + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + ggml_set_op_params_i32(result, 0, s0); + ggml_set_op_params_i32(result, 1, s1); + ggml_set_op_params_i32(result, 2, s2); + ggml_set_op_params_i32(result, 3, p0); + ggml_set_op_params_i32(result, 4, p1); + ggml_set_op_params_i32(result, 5, p2); + ggml_set_op_params_i32(result, 6, d0); + ggml_set_op_params_i32(result, 7, d1); + ggml_set_op_params_i32(result, 8, d2); + ggml_set_op_params_i32(result, 9, c); + ggml_set_op_params_i32(result, 10, n); + ggml_set_op_params_i32(result, 11, oc); + + result->op = GGML_OP_CONV_3D; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // ggml_conv_transpose_2d_p0 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { @@ -4658,11 +4798,36 @@ struct ggml_tensor * ggml_pad( int p1, int p2, int p3) { + return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3); +} + +struct ggml_tensor * ggml_pad_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + int lp0, + int rp0, + int lp1, + int rp1, + int lp2, + int rp2, + int lp3, + int rp3 + ) { struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, - a->ne[0] + p0, - a->ne[1] + p1, - a->ne[2] + p2, - a->ne[3] + p3); + a->ne[0] + lp0 + rp0, + a->ne[1] + lp1 + rp1, + a->ne[2] + lp2 + rp2, + a->ne[3] + lp3 + rp3); + + ggml_set_op_params_i32(result, 0, lp0); + ggml_set_op_params_i32(result, 1, rp0); + ggml_set_op_params_i32(result, 2, lp1); + ggml_set_op_params_i32(result, 3, rp1); + ggml_set_op_params_i32(result, 4, lp2); + ggml_set_op_params_i32(result, 5, rp2); + ggml_set_op_params_i32(result, 6, lp3); + ggml_set_op_params_i32(result, 7, rp3); + result->op = GGML_OP_PAD; result->src[0] = a; @@ -4758,12 +4923,8 @@ struct ggml_tensor * ggml_timestep_embedding( struct ggml_tensor * timesteps, int dim, int max_period) { - int actual_dim = dim; - if (dim % 2 != 0) { - actual_dim = dim + 1; - } - struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]); + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, timesteps->ne[0]); ggml_set_op_params_i32(result, 0, dim); ggml_set_op_params_i32(result, 1, max_period); diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 043c5a767a1..f92eac453e4 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -a933eb1ccbfe3b290384a67a9db8eb2b596ccfff +332a82bc193daffc637952bfb1488441d1e59b10