From 530ef06e82cdd7b4eb40aac510351b1d8ffcd649 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 20 Sep 2025 17:27:49 +0700
Subject: [PATCH 1/3] mtmd: more optimized build_rope_2d

---
 tools/mtmd/clip.cpp | 24 +++++++++---------------
 1 file changed, 9 insertions(+), 15 deletions(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 61420193daef0..e25b609a2ac41 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -571,7 +571,7 @@ struct clip_graph {
         ggml_set_input(pos_w);
 
         auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
-            return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta, true);
+            return build_rope_2d(gf, ctx0, cur, pos_h, pos_w, hparams.rope_theta, true);
         };
 
         ggml_tensor * inp = build_inp();
@@ -1013,7 +1013,7 @@ struct clip_graph {
             // first half is X axis and second half is Y axis
             // ref: https://github.com/huggingface/transformers/blob/40a493c7ed4f19f08eadb0639cf26d49bfa5e180/src/transformers/models/llama4/modeling_llama4.py#L1312
             // ref: https://github.com/Blaizzy/mlx-vlm/blob/a57156aa87b33cca6e5ee6cfc14dd4ef8f611be6/mlx_vlm/models/llama4/vision.py#L441
-            return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
+            return build_rope_2d(gf, ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
         };
         ggml_tensor * cur = build_vit(
                                 inp, n_pos,
@@ -1088,7 +1088,7 @@ struct clip_graph {
         // build ViT with 2D position embeddings
         auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
             // first half is X axis and second half is Y axis
-            return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
+            return build_rope_2d(gf, ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
         };
 
         ggml_tensor * inp = build_inp();
@@ -1975,9 +1975,8 @@ struct clip_graph {
     }
 
     // implementation of the 2D RoPE without adding a new op in ggml
-    // this is not efficient (use double the memory), but works on all backends
-    // TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
     static ggml_tensor * build_rope_2d(
+        ggml_cgraph * gf,
         ggml_context * ctx0,
         ggml_tensor * cur,
         ggml_tensor * pos_a, // first half
@@ -2002,16 +2001,10 @@ struct clip_graph {
             : 1.0;
 
         // first half
-        ggml_tensor * first;
         {
-            first = ggml_view_3d(ctx0, cur,
-                n_dim/2, n_head, n_pos,
-                ggml_row_size(cur->type, n_dim),
-                ggml_row_size(cur->type, n_dim*n_head),
-                0);
-            first = ggml_rope_ext(
+            cur = ggml_rope_ext(
                 ctx0,
-                first,
+                cur,
                 pos_a,      // positions
                 nullptr,    // freq factors
                 n_dim/2,    // n_dims
@@ -2028,7 +2021,8 @@ struct clip_graph {
                 ggml_row_size(cur->type, n_dim),
                 ggml_row_size(cur->type, n_dim*n_head),
                 n_dim/2 * ggml_element_size(cur));
-            second = ggml_rope_ext(
+            // "second" tensor should be on the same backend as ggml_rope_ext(), therefore we can use inplace version
+            second = ggml_rope_ext_inplace(
                 ctx0,
                 second,
                 pos_b,      // positions
@@ -2038,9 +2032,9 @@ struct clip_graph {
                 freq_scale_odd,
                 0.0f, 1.0f, 0.0f, 0.0f
             );
+            ggml_build_forward_expand(gf, second);
         }
 
-        cur = ggml_concat(ctx0, first, second, 0);
         return cur;
     }
 
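Note: the shape of the optimization above, condensed into a minimal sketch (pieced together from the hunks, not verbatim clip.cpp code; `cur` has shape [n_dim, n_head, n_pos]). The old path roped two ggml_view_3d halves into fresh tensors and glued them back with ggml_concat, materializing an extra full-size copy of the activation. The new path ropes the first half directly on `cur` (ggml_rope_ext rotates only the leading n_dims columns and passes the rest through unchanged), then ropes the second half in place through a strided view:

    // first half: rotate only the leading n_dim/2 columns of `cur`;
    // the trailing columns are copied through unchanged
    cur = ggml_rope_ext(ctx0, cur, pos_a, nullptr,
        n_dim/2, 0, 0, freq_base,
        1.0f, 0.0f, 1.0f, 0.0f, 0.0f);

    // second half: a non-contiguous view into the tensor produced above,
    // starting n_dim/2 elements into each row
    ggml_tensor * second = ggml_view_3d(ctx0, cur,
        n_dim/2, n_head, n_pos,
        ggml_row_size(cur->type, n_dim),        // row stride spans the full n_dim
        ggml_row_size(cur->type, n_dim*n_head), // plane stride
        n_dim/2 * ggml_element_size(cur));      // byte offset within each row

    // inplace rope writes straight back into `cur`'s buffer
    second = ggml_rope_ext_inplace(ctx0, second, pos_b, nullptr,
        n_dim/2, 0, 0, freq_base,
        freq_scale_odd, 0.0f, 1.0f, 0.0f, 0.0f);

    // `second` has no consumer (the function returns `cur`), so it must be
    // expanded into the graph explicitly or the node would be pruned
    ggml_build_forward_expand(gf, second);

This is also why build_rope_2d now takes the ggml_cgraph * gf argument: without the explicit ggml_build_forward_expand, nothing references the inplace rope node and the second half would never be rotated. The prerequisite is that ggml_rope_ext_inplace copes with non-contiguous tensors, which is what the next two patches add test coverage for.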
From 77ae46f18711a741cdc9d57c3242f0a4e195019b Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 20 Sep 2025 17:28:02 +0700
Subject: [PATCH 2/3] add test for non-cont inplace rope

---
 tests/test-backend-ops.cpp | 24 +++++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 507b691dc96e2..686fdd0eb733b 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3768,6 +3768,7 @@ struct test_rope : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne_a;
     int n_dims;
+    int offset;
     int mode;
     int n_ctx; // used to generate positions
     float fs; // freq_scale
@@ -3779,16 +3780,17 @@ struct test_rope : public test_case {
 
     std::string vars() override {
         // forward can be inferred from the op, does not need to be printed
-        return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
+        return VARS_TO_STR11(type, ne_a, n_dims, offset, mode, n_ctx, fs, ef, af, ff, v);
     }
 
     test_rope(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
-            int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f,
+            int n_dims = 10, int offset = 0, int mode = 0, int n_ctx = 512, float fs = 1.0f,
             float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true)
-        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {}
+        : type(type), ne_a(ne_a), n_dims(n_dims), offset(offset), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
+        bool inplace = false;
         ggml_tensor * a;
         if (v & 1) {
             auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
@@ -3808,6 +3810,14 @@ struct test_rope : public test_case {
             ggml_set_name(a, "a");
         }
 
+        if (offset > 0) {
+            inplace = true;
+            a = ggml_view_3d(ctx, a, a->ne[0] - offset, a->ne[1], a->ne[2],
+                             ggml_row_size(a->type, a->ne[0]),
+                             ggml_row_size(a->type, a->ne[0]*a->ne[1]),
+                             offset * ggml_element_size(a));
+        }
+
         const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
         const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
@@ -3846,12 +3856,12 @@ struct test_rope : public test_case {
             }
         } else {
             if (forward) {
-                out = ggml_rope_ext     (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                out = inplace
+                    ? ggml_rope_ext_inplace(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f)
+                    : ggml_rope_ext        (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
             } else {
-                out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                out = ggml_rope_ext_back   (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
             }
-
-            // TODO: add test with a non-contiguous view as input ; this case is needed for build_rope_2d in clip.cpp
         }
 
         ggml_set_name(out, "out");
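Note: the new offset parameter is what makes the input non-contiguous. The view keeps the backing tensor's full row stride while shrinking ne[0], so each row of the view is followed by a gap of `offset` elements in memory. A self-contained sketch of the layout (assuming an initialized ggml_context `ctx`, F32, and the {128, 64, 2, 1} / offset = 64 shape that patch 3 registers):

    const int offset = 64;
    ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 64, 2);

    ggml_tensor * v = ggml_view_3d(ctx, a,
        a->ne[0] - offset, a->ne[1], a->ne[2],     // 64 x 64 x 2 visible elements
        ggml_row_size(a->type, a->ne[0]),          // rows still stride by 128 floats
        ggml_row_size(a->type, a->ne[0]*a->ne[1]), // planes stride by 128*64 floats
        offset * ggml_element_size(a));            // start 64 floats into each row

    // 64 used floats followed by a 64-float hole, per row
    GGML_ASSERT(!ggml_is_contiguous(v));

ggml_rope_ext_inplace then has to read and write through exactly these strides on every backend, which is the access pattern build_rope_2d's second half produces in patch 1.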
From e2427dbe02374bcdb2f78525d25de52c4e89d9fa Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 20 Sep 2025 18:35:44 +0700
Subject: [PATCH 3/3] fix test definition

---
 tests/test-backend-ops.cpp | 39 ++++++++++++++++++++-------------------
 1 file changed, 20 insertions(+), 19 deletions(-)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 686fdd0eb733b..a62084c229c33 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -6431,37 +6431,38 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
                 for (bool ff : {false, true}) { // freq_factors
                     for (float v : { 0, 1 }) {
-                        test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
+                        test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, 0, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
 
                         if (all) {
-                            test_cases.emplace_back(new test_rope(type, {128,  40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 13B
-                            test_cases.emplace_back(new test_rope(type, {128,  52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 30B
-                            test_cases.emplace_back(new test_rope(type, {128,  64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 65B
+                            test_cases.emplace_back(new test_rope(type, {128,  40, 2, 1}, 128, 0, 0, 512, fs, ef, af, ff, v, fw)); // llama 13B
+                            test_cases.emplace_back(new test_rope(type, {128,  52, 2, 1}, 128, 0, 0, 512, fs, ef, af, ff, v, fw)); // llama 30B
+                            test_cases.emplace_back(new test_rope(type, {128,  64, 2, 1}, 128, 0, 0, 512, fs, ef, af, ff, v, fw)); // llama 65B
                         }
 
                         if (all) {
-                            test_cases.emplace_back(new test_rope(type, { 64,   1, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
-                            test_cases.emplace_back(new test_rope(type, { 64,  71, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
-                            test_cases.emplace_back(new test_rope(type, { 64,   8, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+                            test_cases.emplace_back(new test_rope(type, { 64,   1, 2, 1},  64, 0, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+                            test_cases.emplace_back(new test_rope(type, { 64,  71, 2, 1},  64, 0, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+                            test_cases.emplace_back(new test_rope(type, { 64,   8, 2, 1},  64, 0, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
 
-                            test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, 0, 512, fs, ef, af, ff, v, fw));
-                            test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, 0, 512, fs, ef, af, ff, v, fw));
-                            test_cases.emplace_back(new test_rope(type, { 80,  32, 4, 1},  32, 0, 512, fs, ef, af, ff, v, fw));
+                            test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, 0, 0, 512, fs, ef, af, ff, v, fw));
+                            test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, 0, 0, 512, fs, ef, af, ff, v, fw));
+                            test_cases.emplace_back(new test_rope(type, { 80,  32, 4, 1},  32, 0, 0, 512, fs, ef, af, ff, v, fw));
 
-                            test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
-                            test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
-                            test_cases.emplace_back(new test_rope(type, { 80,  32, 4, 1},  32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                            test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, 0, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
+                            test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, 0, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                            test_cases.emplace_back(new test_rope(type, { 80,  32, 4, 1},  32, 0, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
                         }
 
                         if (all) {
-                            test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
-                            test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
-                            test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1},  20, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));
-                            test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1},  32, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));
-                            test_cases.emplace_back(new test_rope(type, { 80,  16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
+                            test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1}, 128, 0, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
+                            test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1}, 128, 0, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
+                            test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1},  20, 0, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));
+                            test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1},  32, 0, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));
+                            test_cases.emplace_back(new test_rope(type, { 80,  16, 2, 1},  80, 0, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
                         }
 
-                        test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+                        test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1},  64,  0, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+                        test_cases.emplace_back(new test_rope(type, {128,  64, 2, 1},  64, 64, 2, 512, fs, ef, af, ff, v, fw)); // 2D-RoPE (second half)
 
                     }
                 }
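Note: every pre-existing test_rope case gains the new offset argument in position 4 (always 0, i.e. contiguous input), and the one genuinely new case at the end exercises the non-contiguous path. Decoded against the constructor from patch 2 (a hedged reading, matching the parameter order):

    // test_rope(type, ne_a,            n_dims, offset, mode, n_ctx, ...)
    //           type, {128, 64, 2, 1}, 64,     64,     2,    512,   ...
    //
    // backing rows are 128 wide; the test ropes a 64-wide view starting
    // 64 elements into each row, with mode 2 (GGML_ROPE_TYPE_NEOX): the
    // same shape as the "second" tensor in build_rope_2d from patch 1.

Assuming the harness's usual op filter, an invocation along the lines of `test-backend-ops test -o ROPE` should run just these cases on each available backend.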