diff --git a/mllm/models/llama/modeling_llama.hpp b/mllm/models/llama/modeling_llama.hpp
index 58a033650..3b4142c68 100644
--- a/mllm/models/llama/modeling_llama.hpp
+++ b/mllm/models/llama/modeling_llama.hpp
@@ -377,8 +377,10 @@ class LlamaForCausalLM : public nn::Module, public ARGeneration {
   bool tie_word_embeddings_;
   bool mask_by_tensor_;
 
+  inline nn::AbstractStaticCache& kvCache() { return *kv_cache_; }
+
  private:
   std::unique_ptr<nn::AbstractStaticCache> kv_cache_;
 };
-}  // namespace mllm::models::llama
\ No newline at end of file
+}  // namespace mllm::models::llama
diff --git a/tools/mllm-llm-benchmark/README.md b/tools/mllm-llm-benchmark/README.md
index 2709ecaa5..f9285aca6 100644
--- a/tools/mllm-llm-benchmark/README.md
+++ b/tools/mllm-llm-benchmark/README.md
@@ -10,7 +10,6 @@ This is a benchmark tool for measuring MLLM model performance, including:
 ## Build
 
 Build from the mllm_v2 project root directory:
-
 ```bash
 mkdir -p build && cd build
 cmake ..
@@ -20,7 +19,6 @@ make mllm-llm-benchmark
 ## Usage
 
 ### Basic Usage
-
 ```bash
 ./mllm-llm-benchmark \
   -n qwen3-w4a32-kai \
@@ -32,6 +30,47 @@ make mllm-llm-benchmark
   -cl 2048
 ```
 
+### Context Sweep (New Feature)
+
+For automated benchmarking across different context lengths, use the sweep script:
+```bash
+cd tools/mllm-llm-benchmark
+chmod +x scripts/sweep_context_v2.sh
+
+# Configure paths
+export BIN=../../build/bin/mllm-llm-benchmark
+export MODEL=/path/to/your-model.mllm
+export CFG=/path/to/config.json
+
+# Run sweep
+./scripts/sweep_context_v2.sh
+```
+
+Output goes to `bench_context/context_sweep_v2.csv`.
+
+**Configuration options:**
+- `BIN`: Path to the benchmark binary (required)
+- `MODEL`: Path to the model file (required)
+- `CFG`: Path to the config JSON (default: `./examples/llama/config_tiny_llama.json`)
+- `THREADS`: Number of threads (default: 8)
+- `RUNS`: Number of runs to average (default: 1)
+- `COOLDOWN`: Seconds to wait between runs (default: 0)
+- `CTX_LENS`: Context lengths to test (default: "256 512 1024 2048 4096")
+- `TG_DH`: Generation length for `decode_heavy` mode (default: 256)
+- `TG_TTFT`: Generation length for `prefill_ttft` mode (default: 2)
+- `OUTDIR`: Output directory (default: `bench_context`)
+
+**Test modes:**
+- `prefill_ttft`: Measures time to first token (prompt length = `CTX_LEN - 2`, generates 2 tokens)
+- `decode_heavy`: Measures decode throughput (prompt length = `CTX_LEN - 256`, generates 256 tokens)
+
+### Plot Results
+
+Visualize benchmark results:
+```bash
+python3 scripts/plot_sweep.py bench_context/context_sweep_v2.csv output_dir/
+```
+
 ### Parameters
 
 | Parameter | Long Format | Description | Example |
@@ -47,7 +86,6 @@ make mllm-llm-benchmark
 ### Examples
 
 #### Testing Qwen3-0.6B Model
-
 ```bash
 ./mllm-llm-benchmark \
   -n qwen3-w4a32-kai \
@@ -60,7 +98,6 @@ make mllm-llm-benchmark
 ```
 
 #### Quick Test (Single Configuration)
-
 ```bash
 ./mllm-llm-benchmark \
   -n qwen3-w4a32-kai \
@@ -73,7 +110,6 @@ make mllm-llm-benchmark
 ```
 
 ## Output Example
-
 ```
 MLLM Build Version : abc123def456
 ARCH               : ARM64
@@ -144,7 +180,6 @@ Each test configuration executes the following steps:
 ### 1. Create New Benchmark Class
 
 Create `YourModel_Benchmark.hpp` in the `models/` directory:
-
 ```cpp
 #include "BenchmarkTemplate.hpp"
 #include <string>
@@ -178,7 +213,6 @@ class YourModel_Benchmark final : public BenchmarkTemplate {
 ```
 
 ### 2. Register in All.hpp
-
 ```cpp
 #include "YourModel_Benchmark.hpp"
 
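Alongside the sweep script, the benchmark binary itself gains `-r/--runs`, `-cs/--cooldown_s` and `-oc/--output_csv` flags (see the main.cpp diff below); `-oc` writes one CSV row per PP/TG configuration. A minimal sketch of consuming that file, with the column names taken from the header main.cpp writes and `bench.csv` as a placeholder path:

```python
import csv

# Columns match the header emitted by main.cpp when --output_csv is set:
# schema_version,git_commit,arch,model_name,cache_length,pp,tg,ttft_ms,
# prefill_speed,decode_speed,prefill_ms,decode_ms_per_tok,
# kv_est_bytes_pp,kv_est_bytes_final
with open("bench.csv", newline="") as f:
    for row in csv.DictReader(f):
        print(f"pp={row['pp']} tg={row['tg']} "
              f"ttft={float(row['ttft_ms']):.1f} ms "
              f"decode={float(row['decode_speed']):.1f} tok/s "
              f"kv_final={float(row['kv_est_bytes_final']) / 2**20:.1f} MiB")
```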
diff --git a/tools/mllm-llm-benchmark/main.cpp b/tools/mllm-llm-benchmark/main.cpp
index af275a2e6..e80284ce7 100644
--- a/tools/mllm-llm-benchmark/main.cpp
+++ b/tools/mllm-llm-benchmark/main.cpp
@@ -1,10 +1,13 @@
 // Copyright (c) MLLM Team.
 // Licensed under the MIT License.
+#include <cstdint>
+#include <fstream>
 #include <chrono>
 #include <memory>
 #include <string>
 #include <thread>
+#include <sstream>
 #include <utility>
 #include <vector>
@@ -16,6 +19,9 @@
 
 #include "models/All.hpp"
 
+#define STR_HELPER(x) #x
+#define STR(x) STR_HELPER(x)
+
 MLLM_MAIN({
   auto& help = mllm::Argparse::add<bool>("-h|--help").help("Show help message");
   auto& model_name = mllm::Argparse::add<std::string>("-n|--model_name").help("Model name");
@@ -25,12 +31,21 @@ MLLM_MAIN({
   auto& pp = mllm::Argparse::add<std::string>("-pp|--prompt_length").help("Prompt length");
   auto& tg = mllm::Argparse::add<std::string>("-tg|--test_generation_length").help("Test Generation length");
   auto& cache_length = mllm::Argparse::add<int32_t>("-cl|--cache_length").help("Cache length");
+
+  auto& runs = mllm::Argparse::add<int32_t>("-r|--runs").help("Number of benchmark runs").def(3);
+  auto& cooldown_s = mllm::Argparse::add<int32_t>("-cs|--cooldown_s").help("Cooldown time between runs in seconds").def(5);
+  auto& output_csv = mllm::Argparse::add<std::string>("-oc|--output_csv").help("Output results to a CSV file").def("");
+  auto& schema_version = mllm::Argparse::add<int32_t>("-sv|--schema_version").help("Schema version for output format").def(1);
+  auto& kv_dtype_bytes =
+      mllm::Argparse::add<int32_t>("-kv|--kv_dtype_bytes").help("KV cache data type bytes (1: int8, 2: fp16, 4: fp32)").def(4);
+
   mllm::Argparse::parse(argc, argv);
 
-  // Print Build Version
+  mllm::Context::instance().setCpuOpThreads(num_threads.get());
+  mllm::setMaximumNumThreads((uint32_t)num_threads.get());
+
   mllm::print("MLLM Build Version :", STRINGIFY(MLLM_GIT_COMMIT_HASH));
-  // Print Device Info
   mllm::print("ARCH :", mllm::cpu::CURRENT_ARCH_STRING);
   mllm::print("FP16 :", mllm::cpu::hasFP16());
   mllm::print("BF16 :", mllm::cpu::hasBF16());
@@ -53,15 +68,31 @@ MLLM_MAIN({
   mllm::print("AVX512VL :", mllm::cpu::hasAVX512VL());
   mllm::print("FMA :", mllm::cpu::hasFMA());
 
-  // Create benchmark
   mllm::print("Create Benchmark: ", model_name.get());
   auto benchmark = createBenchmark(model_name.get());
   MLLM_RT_ASSERT(benchmark != nullptr);
 
-  // Print Model Info
+  int R = runs.get();
+  if (R <= 0) {
+    mllm::print("[ERROR] --runs must be > 0, got:", R);
+    return 1;
+  }
+
+  std::ofstream csv_file;
+  if (!output_csv.get().empty()) {
+    csv_file.open(output_csv.get());
+    if (!csv_file.is_open()) {
+      mllm::print("[ERROR] Failed to open --output_csv:", output_csv.get());
+      return 1;
+    }
+    csv_file << "schema_version,git_commit,arch,model_name,cache_length,pp,tg,ttft_ms,prefill_speed,decode_speed,prefill_ms,decode_ms_per_"
+                "tok,kv_est_bytes_pp,kv_est_bytes_final\n";
+  }
+
   mllm::print("Model Info");
   benchmark->init(config_path.get(), model_path.get(), cache_length.get());
   benchmark->printModelInfo();
+  mllm::print("Cache Length :", cache_length.get());
 
   // Warmup run
   mllm::print("Warmup Run");
@@ -92,7 +123,7 @@ MLLM_MAIN({
     for (size_t i = 0; i < pp_values.size(); ++i) { pp_tg_pairs.emplace_back(pp_values[i], tg_values[i]); }
   }
 
-  // Actual run for 3 turns and gives avg results. Each turn will sleep for 5 seconds to let the SoC or GPU/NPU cool down.
+  // Actual run for a configurable number of turns
   mllm::print("\n========================================");
   mllm::print("Starting Benchmark Tests");
   mllm::print("========================================\n");
@@ -104,17 +135,13 @@ MLLM_MAIN({
     mllm::print(" Generation Length (TG):", tg);
     mllm::print("----------------------------------------");
 
-    // Storage for results
     std::vector<BenchmarkTemplateResult> results;
-    results.reserve(3);
+    results.reserve(static_cast<size_t>(R));
 
-    for (int i = 0; i < 3; ++i) {
-      mllm::print(" Run", i + 1, "of 3...");
+    for (int i = 0; i < R; ++i) {
+      mllm::print(" Run", i + 1, "of", R, "...");
 
-      // Clear cache before each run
       benchmark->clear();
-
-      // Run benchmark
       auto result = benchmark->run(pp, tg);
       results.push_back(result);
 
@@ -122,14 +149,19 @@ MLLM_MAIN({
       mllm::print(" Prefill Speed:", result.prefill_speed, "tokens/s");
       mllm::print(" Decode Speed :", result.decode_speed, "tokens/s");
 
-      // Sleep for 5 seconds between runs to cool down
-      if (i < 2) {
-        mllm::print(" Cooling down for 5 seconds...");
-        std::this_thread::sleep_for(std::chrono::seconds(5));
+      float prefill_ms = (result.prefill_speed > 0.0f) ? (pp / result.prefill_speed) * 1000.0f : 0.0f;
+      float decode_ms_per_tok = (result.decode_speed > 0.0f) ? (1.0f / result.decode_speed) * 1000.0f : 0.0f;
+      mllm::print(" Prefill Latency :", prefill_ms, "ms");
+      mllm::print(" Decode Latency :", decode_ms_per_tok, "ms");
+
+      int cool = cooldown_s.get();
+      if (i + 1 < R && cool > 0) {
+        mllm::print(" Cooling down for", cool, "seconds...");
+        std::this_thread::sleep_for(std::chrono::seconds(cool));
       }
     }
 
-    // Calculate average results
+    float denom = (R > 0) ? static_cast<float>(R) : 1.0f;
     float avg_ttft = 0.0f;
     float avg_prefill_speed = 0.0f;
     float avg_decode_speed = 0.0f;
@@ -151,9 +183,35 @@ MLLM_MAIN({
+    // Average over R runs before printing and logging
+    avg_ttft /= denom;
+    avg_prefill_speed /= denom;
+    avg_decode_speed /= denom;
+
     mllm::print("Average Prefill Speed:", avg_prefill_speed, "tokens/s");
     mllm::print("Average Decode Speed :", avg_decode_speed, "tokens/s");
     mllm::print("=====================================\n");
+
+    float avg_prefill_ms = (avg_prefill_speed > 0.0f) ? (pp / avg_prefill_speed) * 1000.0f : 0.0f;
+    float avg_decode_ms_per_tok = (avg_decode_speed > 0.0f) ? (1.0f / avg_decode_speed) * 1000.0f : 0.0f;
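The KV estimate above is easy to sanity-check by hand. A minimal sketch, assuming TinyLlama-1.1B-style shapes (hypothetical numbers; read yours from the model's config.json) and fp16 elements:

```python
def kv_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem):
    # Same formula as main.cpp and sweep_context_v2.sh:
    # K and V tensors -> factor 2
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem

# Hypothetical TinyLlama-1.1B-style config: 22 layers, 4 KV heads,
# head_dim = hidden_size / num_attention_heads = 2048 / 32 = 64
b = kv_bytes(22, 4, 64, 2048, 2)  # fp16, ctx_len 2048
print(b, b / 1024, b / 2**20)     # 46137344 bytes = 45056.0 KB = 44.0 MiB
```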
+
+    // KV cache estimate
+    double kv_est_bytes_pp = 0.0;
+    double kv_est_bytes_final = 0.0;
+    if (auto info = benchmark->kvEstimateInfo(); info.has_value()) {
+      const int32_t bytes_per = kv_dtype_bytes.get();  // 1 / 2 / 4
+      // LLaMA-like KV: 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes
+      kv_est_bytes_pp = 2.0 * info->num_layers * info->num_kv_heads * info->head_dim * (double)pp * bytes_per;
+      kv_est_bytes_final = 2.0 * info->num_layers * info->num_kv_heads * info->head_dim * (double)(pp + tg) * bytes_per;
+    }
+
+    std::stringstream ss;
+    ss << schema_version.get() << "," << STRINGIFY(MLLM_GIT_COMMIT_HASH) << "," << mllm::cpu::CURRENT_ARCH_STRING << ","
+       << model_name.get() << "," << cache_length.get() << "," << pp << "," << tg << "," << avg_ttft << ","
+       << avg_prefill_speed << "," << avg_decode_speed << "," << avg_prefill_ms << "," << avg_decode_ms_per_tok << ","
+       << kv_est_bytes_pp << "," << kv_est_bytes_final;
+
+    if (csv_file.is_open()) { csv_file << ss.str() << std::endl; }
   }
 
   mllm::print("\n========================================");
   mllm::print("Benchmark Tests Completed");
   mllm::print("========================================");
+
+  if (csv_file.is_open()) { csv_file.close(); }
 })
diff --git a/tools/mllm-llm-benchmark/models/All.hpp b/tools/mllm-llm-benchmark/models/All.hpp
index 340fe6bf8..c8c528494 100644
--- a/tools/mllm-llm-benchmark/models/All.hpp
+++ b/tools/mllm-llm-benchmark/models/All.hpp
@@ -4,20 +4,33 @@
 
 #include <algorithm>
 #include <string>
+#include <cctype>
+#include <memory>
 
-#include "Qwen3_W4A32_KAI.hpp"
 #include "BenchmarkTemplate.hpp"
+#include "Qwen3_W4A32_KAI.hpp"
+#include "Llama.hpp"
 
-std::shared_ptr<BenchmarkTemplate> createBenchmark(const std::string& model_name) {
+inline std::shared_ptr<BenchmarkTemplate> createBenchmark(const std::string& model_name) {
   auto tolower = [](const std::string& str) {
     std::string result = str;
-    std::transform(result.begin(), result.end(), result.begin(), ::tolower);
+    // unsigned char cast to avoid UB
+    std::transform(result.begin(), result.end(), result.begin(),
+                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
     return result;
   };
+
   auto normalized_model_name = tolower(model_name);
+
   if (normalized_model_name.find("qwen3") != std::string::npos && normalized_model_name.find("w4a32") != std::string::npos
       && normalized_model_name.find("kai") != std::string::npos) {
     return std::make_shared<Qwen3_W4A32_KAI_Benchmark>();
   }
+
+  if (normalized_model_name.find("llama") != std::string::npos || normalized_model_name.find("tinyllama") != std::string::npos
+      || normalized_model_name.find("tiny_llama") != std::string::npos) {
+    return std::make_shared<Llama_Benchmark>();
+  }
+
   return nullptr;
 }
diff --git a/tools/mllm-llm-benchmark/models/BenchmarkTemplate.hpp b/tools/mllm-llm-benchmark/models/BenchmarkTemplate.hpp
index 4724a8ca8..d1f21b014 100644
--- a/tools/mllm-llm-benchmark/models/BenchmarkTemplate.hpp
+++ b/tools/mllm-llm-benchmark/models/BenchmarkTemplate.hpp
@@ -3,19 +3,27 @@
 #pragma once
 
 #include <string>
+#include <cstdint>
+#include <optional>
 
 /**
  * @brief Benchmark result structure
 */
 struct BenchmarkTemplateResult {
-  float ttft; ///< Time To First Token in milliseconds
-  float prefill_speed; ///< Prefill phase speed in tokens/s
-  float decode_speed; ///< Decode phase speed in tokens/s
+  float ttft;           ///< Time To First Token in milliseconds
+  float prefill_speed;  ///< Prefill phase speed in tokens/s
+  float decode_speed;   ///< Decode phase speed in tokens/s
+};
+
+struct KVCacheEstimateInfo {
+  int32_t num_layers = 0;
+  int32_t num_kv_heads = 0;
+  int32_t head_dim = 0;  // hidden_size / num_attention_heads
 };
 
 /**
  * 
 * @brief Base class for benchmark templates
- * 
+ *
  * All model benchmark implementations should inherit from this class and implement all virtual functions.
  */
 class BenchmarkTemplate {
@@ -32,21 +40,21 @@ class BenchmarkTemplate {
 
   /**
    * @brief Print model information
-   * 
+   *
    * Should output model key parameters such as number of layers, hidden size, attention heads, etc.
    */
   virtual void printModelInfo() = 0;
 
   /**
    * @brief Warmup run
-   * 
+   *
    * Run the model once with small-scale input to ensure the model enters a stable state.
    */
   virtual void warmup() = 0;
 
   /**
    * @brief Clear cache
-   * 
+   *
    * Clear KV cache and performance counters to prepare for the next test.
    */
   virtual void clear() = 0;
@@ -58,4 +66,7 @@ class BenchmarkTemplate {
    * @return Test results
    */
   virtual BenchmarkTemplateResult run(int32_t pp, int32_t tg) = 0;
+
+  // KV cache size estimation; return nullopt if unsupported
+  virtual std::optional<KVCacheEstimateInfo> kvEstimateInfo() const { return std::nullopt; }
 };
diff --git a/tools/mllm-llm-benchmark/models/Llama.hpp b/tools/mllm-llm-benchmark/models/Llama.hpp
new file mode 100644
index 000000000..63ae2bbb9
--- /dev/null
+++ b/tools/mllm-llm-benchmark/models/Llama.hpp
@@ -0,0 +1,134 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+#pragma once
+
+#include <chrono>
+#include <memory>
+#include <optional>
+
+#include "BenchmarkTemplate.hpp"
+
+#include <mllm/mllm.hpp>
+#include <mllm/models/llama/configuration_llama.hpp>
+#include <mllm/models/llama/modeling_llama.hpp>
+
+class Llama_Benchmark final : public BenchmarkTemplate {
+ public:
+  std::optional<KVCacheEstimateInfo> kvEstimateInfo() const override {
+    if (!cfg_) return std::nullopt;
+    KVCacheEstimateInfo info;
+    info.num_layers = cfg_->num_hidden_layers;
+    info.num_kv_heads = cfg_->num_key_value_heads;
+    info.head_dim = cfg_->hidden_size / cfg_->num_attention_heads;
+    return info;
+  }
+
+  void init(const std::string& cfg_path, const std::string& model_path, int32_t cache_length) override {
+    cfg_ = std::make_unique<mllm::models::llama::LlamaConfig>(cfg_path);
+
+    // The LLaMA config uses max_position_embeddings as the KV-cache upper bound
+    if (cache_length > 0) { cfg_->max_position_embeddings = cache_length; }
+
+    model_ = std::make_unique<mllm::models::llama::LlamaForCausalLM>("", *cfg_);
+
+    // V1 param file only
+    auto param = mllm::load(model_path, mllm::ModelFileVersion::kV1);
+    model_->load(param);
+
+    mllm::print("Model initialized successfully");
+  }
+
+  void printModelInfo() override {
+    if (!cfg_) return;
+    mllm::print("========== Model Information ==========");
+    mllm::print("Model Type         : LLaMA / TinyLlama");
+    mllm::print("Hidden Size        :", cfg_->hidden_size);
+    mllm::print("Num Layers         :", cfg_->num_hidden_layers);
+    mllm::print("Num Heads          :", cfg_->num_attention_heads);
+    mllm::print("Num KV Heads       :", cfg_->num_key_value_heads);
+    int32_t head_dim = (cfg_->num_attention_heads > 0) ? (cfg_->hidden_size / cfg_->num_attention_heads) : 0;
+    mllm::print("Head Dim           :", head_dim);
+    mllm::print("Intermediate Size  :", cfg_->intermediate_size);
+    mllm::print("Vocab Size         :", cfg_->vocab_size);
+    mllm::print("Max Pos Embeddings :", cfg_->max_position_embeddings);
+    mllm::print("=======================================");
+  }
+
+  void warmup() override {
+    if (!model_) return;
+
+    const int32_t warmup_length = 8;
+    const int32_t warmup_gen = 4;
+
+    auto input_ids = mllm::Tensor::empty({1, warmup_length}, mllm::kInt64, mllm::kCPU).setMemType(mllm::kNormal).alloc();
+    auto ptr = input_ids.ptr<int64_t>();
+    for (int i = 0; i < warmup_length; ++i) ptr[i] = 1;
+
+    mllm::models::ARGenerationOutputPast inputs;
+    inputs["sequence"] = input_ids;
+
+    mllm::models::ARGenerationArgs args;
+    args["max_length"] = mllm::AnyValue((int)warmup_gen);
+    args["do_sample"] = mllm::AnyValue(false);
+
+    model_->generate(inputs, args);
+    mllm::print("Warmup completed");
+  }
+
+  void clear() override {
+    if (!model_) { return; }
+    model_->kvCache().setCurrentSeqCnt(0);
+  }
+
+  BenchmarkTemplateResult run(int32_t pp, int32_t tg) override {
+    if (pp <= 0 || tg < 0) {
+      mllm::print("[ERROR] invalid pp/tg:", pp, tg);
+      return {0.f, 0.f, 0.f};
+    }
+    if (!model_) return {0.f, 0.f, 0.f};
+
+    auto input_ids = mllm::Tensor::empty({1, pp}, mllm::kInt64, mllm::kCPU).setMemType(mllm::kNormal).alloc();
+    auto ptr = input_ids.ptr<int64_t>();
+    for (int i = 0; i < pp; ++i) ptr[i] = 1 + (i % 100);
+
+    mllm::models::ARGenerationOutputPast inputs;
+    inputs["sequence"] = input_ids;
+
+    mllm::models::ARGenerationArgs args;
+    args["max_length"] = mllm::AnyValue((int)tg);
+    args["do_sample"] = mllm::AnyValue(false);
+
+    auto prefill_start = std::chrono::high_resolution_clock::now();
+    auto decode_start = prefill_start;
+    auto decode_end = prefill_start;
+
+    bool first_token = true;
+    int token_count = 0;
+
+    model_->streamGenerate(inputs, args, [&](int64_t /*token_id*/) {
+      if (first_token) {
+        decode_start = std::chrono::high_resolution_clock::now();
+        first_token = false;
+      }
+      token_count++;
+      decode_end = std::chrono::high_resolution_clock::now();
+    });
+
+    auto prefill_us = std::chrono::duration_cast<std::chrono::microseconds>(decode_start - prefill_start).count();
+    auto decode_us = std::chrono::duration_cast<std::chrono::microseconds>(decode_end - decode_start).count();
+
+    BenchmarkTemplateResult r;
+    r.ttft = prefill_us / 1000.0f;
+    r.prefill_speed = (prefill_us > 0) ? (static_cast<float>(pp) / prefill_us) * 1e6f : 0.f;
+    // exclude the first token from decode throughput; it is produced by the prefill pass
+    int decode_tokens = (token_count > 0) ? (token_count - 1) : 0;
+    r.decode_speed = (decode_us > 0 && decode_tokens > 0) ? (static_cast<float>(decode_tokens) / decode_us) * 1e6f : 0.f;
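The timing logic in `run()` above is just unit conversion around three timer readings. A Python mirror of it, with hypothetical microsecond values, makes the arithmetic explicit:

```python
# Hypothetical timer readings, mirroring Llama_Benchmark::run()
pp = 1792                  # prompt tokens
prefill_us = 180_000       # prefill_start -> first streamed token
decode_us = 2_500_000      # first -> last streamed token
token_count = 256          # tokens seen by the stream callback

ttft_ms = prefill_us / 1000                      # 180.0 ms
prefill_speed = pp / prefill_us * 1e6            # ~9955.6 tokens/s
decode_tokens = token_count - 1                  # first token belongs to prefill
decode_speed = decode_tokens / decode_us * 1e6   # 102.0 tokens/s
print(ttft_ms, prefill_speed, decode_speed)
```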
+    return r;
+  }
+
+ private:
+  std::unique_ptr<mllm::models::llama::LlamaConfig> cfg_;
+  std::unique_ptr<mllm::models::llama::LlamaForCausalLM> model_;
+};
diff --git a/tools/mllm-llm-benchmark/scripts/plot_sweep.py b/tools/mllm-llm-benchmark/scripts/plot_sweep.py
new file mode 100644
index 000000000..7af8c27b4
--- /dev/null
+++ b/tools/mllm-llm-benchmark/scripts/plot_sweep.py
@@ -0,0 +1,118 @@
+import sys, os, csv, math
+import matplotlib.pyplot as plt
+
+csv_path = sys.argv[1] if len(sys.argv) > 1 else "bench_context/context_sweep_v2.csv"
+out_dir = sys.argv[2] if len(sys.argv) > 2 else "snapshots"
+os.makedirs(out_dir, exist_ok=True)
+
+def to_float(x):
+    try: return float(x)
+    except (TypeError, ValueError): return float("nan")
+
+def to_int(x):
+    try: return int(float(x))
+    except (TypeError, ValueError): return 0
+
+rows = []
+with open(csv_path, "r", newline="") as f:
+    r = csv.DictReader(f)
+    for row in r:
+        rows.append(row)
+
+# normalize numeric fields
+num_fields = ["ctx_len","pp","tg","threads","ttft_ms","prefill_ms","decode_ms","decode_ms_per_tok","peak_rss_kb","kv_est_kb"]
+for row in rows:
+    for k in num_fields:
+        if k in row:
+            row[k] = to_float(row[k])
+
+# write summary
+stamp = os.path.splitext(os.path.basename(csv_path))[0]
+summary_path = os.path.join(out_dir, f"{stamp}.summary.csv")
+fieldnames = ["ts","git","arch","model","mode","ctx_len","pp","tg","threads",
+              "ttft_ms","prefill_ms","decode_ms","decode_ms_per_tok",
+              "peak_rss_kb","kv_est_kb","peak_rss_gb","kv_est_mb"]
+with open(summary_path, "w", newline="") as f:
+    w = csv.DictWriter(f, fieldnames=fieldnames)
+    w.writeheader()
+    for row in sorted(rows, key=lambda x: (x.get("mode",""), x.get("ctx_len",0))):
+        peak_rss_kb = row.get("peak_rss_kb", float("nan"))
+        kv_est_kb = row.get("kv_est_kb", float("nan"))
+        out = {k: row.get(k, "") for k in fieldnames}
+        # `v == v` is False only for NaN, so these guards skip missing values
+        out["peak_rss_gb"] = (peak_rss_kb / (1024*1024)) if peak_rss_kb == peak_rss_kb else ""
+        out["kv_est_mb"] = (kv_est_kb / 1024) if kv_est_kb == kv_est_kb else ""
+        w.writerow(out)
+
+def plot_mode(mode, xkey, ykey, ylabel, fname):
+    xs, ys = [], []
+    for row in rows:
+        if row.get("mode","") != mode:
+            continue
+        x = row.get(xkey, float("nan"))
+        y = row.get(ykey, float("nan"))
+        if x == x and y == y:  # drop NaN points
+            xs.append(x); ys.append(y)
+    if not xs:
+        return
+    pts = sorted(zip(xs, ys), key=lambda t: t[0])
+    xs = [p[0] for p in pts]
+    ys = [p[1] for p in pts]
+    plt.figure()
+    plt.plot(xs, ys, marker="o")
+    plt.xlabel(xkey)
+    plt.ylabel(ylabel)
+    plt.xscale("log", base=2)
+    plt.grid(True, which="both", linestyle="--", linewidth=0.5)
+    plt.tight_layout()
+    plt.savefig(os.path.join(out_dir, fname), dpi=180)
+    plt.close()
+
+plot_mode("prefill_ttft", "ctx_len", "ttft_ms", "TTFT (ms)", f"{stamp}.prefill_ttft.ttft_ms.png")
+plot_mode("prefill_ttft", "ctx_len", "prefill_ms", "Prefill latency (ms)", f"{stamp}.prefill_ttft.prefill_ms.png")
+plot_mode("decode_heavy", "ctx_len", "decode_ms_per_tok", "Decode latency per token (ms)", f"{stamp}.decode_heavy.decode_ms_per_tok.png")
+plot_mode("decode_heavy", "ctx_len", "decode_ms", "Decode latency total (ms)", f"{stamp}.decode_heavy.decode_ms.png")
+
+# memory plots: keep the max observed value per ctx_len across modes
+mem = {}
+for row in rows:
+    ctx_len = row.get("ctx_len", float("nan"))
+    if ctx_len != ctx_len:
+        continue
+    ctx_len = int(ctx_len)
+    peak = row.get("peak_rss_kb", float("nan"))
+    kv = row.get("kv_est_kb", float("nan"))
+    cur = mem.get(ctx_len, {"peak": float("nan"), "kv": float("nan")})
+    if peak == peak and (cur["peak"] != cur["peak"] or peak > cur["peak"]): cur["peak"] = peak
+    if kv == kv and (cur["kv"] != cur["kv"] or kv > cur["kv"]): cur["kv"] = kv
+    mem[ctx_len] = cur
+
+ctx_lens = sorted(mem.keys())
+peak_gb = [(mem[c]["peak"]/(1024*1024)) if mem[c]["peak"] == mem[c]["peak"] else float("nan") for c in ctx_lens]
+kv_mb = [(mem[c]["kv"]/1024) if mem[c]["kv"] == mem[c]["kv"] else float("nan") for c in ctx_lens]
+
+plt.figure()
+plt.plot(ctx_lens, peak_gb, marker="o")
+plt.xlabel("ctx_len")
+plt.ylabel("Peak RSS (GB)")
+plt.xscale("log", base=2)
+plt.grid(True, which="both", linestyle="--", linewidth=0.5)
+plt.tight_layout()
+plt.savefig(os.path.join(out_dir, f"{stamp}.memory.peak_rss_gb.png"), dpi=180)
+plt.close()
+
+plt.figure()
+plt.plot(ctx_lens, kv_mb, marker="o")
+plt.xlabel("ctx_len")
+plt.ylabel("KV estimate (MB)")
+plt.xscale("log", base=2)
+plt.grid(True, which="both", linestyle="--", linewidth=0.5)
+plt.tight_layout()
+plt.savefig(os.path.join(out_dir, f"{stamp}.memory.kv_est_mb.png"), dpi=180)
+plt.close()
+
+print("Wrote:")
+print("  ", summary_path)
+print("  ", os.path.join(out_dir, f"{stamp}.prefill_ttft.ttft_ms.png"))
+print("  ", os.path.join(out_dir, f"{stamp}.decode_heavy.decode_ms_per_tok.png"))
+print("  ", os.path.join(out_dir, f"{stamp}.memory.peak_rss_gb.png"))
+print("  ", os.path.join(out_dir, f"{stamp}.memory.kv_est_mb.png"))
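The repeated `x == x` guards in plot_sweep.py are its NaN filter; a two-line check shows why the idiom works:

```python
nan = float("nan")
print(nan == nan)   # False: NaN is the only float not equal to itself
print(1.0 == 1.0)   # True: every ordinary value passes the x == x guard
```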
diff --git a/tools/mllm-llm-benchmark/scripts/sweep_context_v2.sh b/tools/mllm-llm-benchmark/scripts/sweep_context_v2.sh
new file mode 100755
index 000000000..ab8b36325
--- /dev/null
+++ b/tools/mllm-llm-benchmark/scripts/sweep_context_v2.sh
@@ -0,0 +1,95 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+BIN="${BIN:-./build/bin/mllm-llm-benchmark}"
+MODEL="${MODEL:-}"
+CFG="${CFG:-./examples/llama/config_tiny_llama.json}"
+THREADS="${THREADS:-8}"
+RUNS="${RUNS:-1}"
+COOLDOWN="${COOLDOWN:-0}"
+TG_DH="${TG_DH:-256}"
+TG_TTFT="${TG_TTFT:-2}"
+CTX_LENS="${CTX_LENS:-256 512 1024 2048 4096}"
+OUTDIR="${OUTDIR:-bench_context}"
+
+mkdir -p "$OUTDIR"
+OUTCSV="$OUTDIR/context_sweep_v2.csv"
+
+echo "ts,git,arch,model,mode,ctx_len,pp,tg,threads,ttft_ms,prefill_ms,decode_ms,decode_ms_per_tok,peak_rss_kb,kv_est_kb" > "$OUTCSV"
+
+GIT="$(git rev-parse --short=12 HEAD 2>/dev/null || echo NA)"
+ARCH="$(uname -m)"
+TS="$(date -Iseconds)"
+
+kv_est_kb() {
+  python3 - <<'PY'
+import json, os, math
+cfg = os.environ["CFG"]
+ctx_len = int(os.environ["CTX_LEN"])
+bpe = int(os.environ.get("KV_BYTES_PER_ELEM","2"))
+j = json.load(open(cfg,"r"))
+L = int(j["num_hidden_layers"])
+H = int(j["num_attention_heads"])
+KVH = int(j.get("num_key_value_heads", H))
+hidden = int(j["hidden_size"])
+head_dim = hidden // H
+kv_bytes = 2 * L * KVH * head_dim * ctx_len * bpe
+print(int(math.ceil(kv_bytes / 1024)))
+PY
+}
+
+run_one () {
+  local mode="$1"
+  local ctx_len="$2"
+  local tg="$3"
+
+  if (( ctx_len <= tg )); then
+    echo "skip: ctx_len=$ctx_len <= tg=$tg (mode=$mode)"
+    return 0
+  fi
+
+  local pp=$((ctx_len - tg))
+  if (( pp < 1 )); then pp=1; fi
+
+  echo "==== mode=$mode ctx_len=$ctx_len pp=$pp tg=$tg ===="
+
+  local ALLLOG="$OUTDIR/run_${mode}_ctx${ctx_len}.all"
+  local TIMELOG="$OUTDIR/run_${mode}_ctx${ctx_len}.time"
+
+  set +e
+  # pass ctx as the cache limit
+  /usr/bin/time -v \
+    "$BIN" -n tiny_llama -m "$MODEL" -c "$CFG" \
+    -pp "$pp" -tg "$tg" -t "$THREADS" -cl "$ctx_len" -r "$RUNS" -cs "$COOLDOWN" \
+    >"$ALLLOG" 2>"$TIMELOG"
+  local EXIT_CODE=$?
+  set -e
+
+  if [ $EXIT_CODE -ne 0 ]; then
+    echo "run failed with exit code $EXIT_CODE: mode=$mode ctx_len=$ctx_len"
+    return 1
+  fi
+
+  local TTFT_MS PREFILL_MS DECODE_MS PEAK_RSS_KB KV_EST_KB
+  TTFT_MS="$(grep -oP 'TTFT\s*:\s*\K[0-9.]+' "$ALLLOG" | head -n 1 || echo 0)"
+  PREFILL_MS="$(grep -oP 'Prefill Latency\s*:\s*\K[0-9.]+' "$ALLLOG" | head -n 1 || echo 0)"
+
+  # main.cpp prints "Decode Latency" in ms per token, so take it as-is
+  # and derive the total decode time from tg
+  local DECODE_PER_TOK
+  DECODE_PER_TOK="$(grep -oP 'Decode Latency\s*:\s*\K[0-9.]+' "$ALLLOG" | head -n 1 || echo 0)"
+  DECODE_MS="$(python3 -c "tg=float('$tg'); d=float('$DECODE_PER_TOK'); print(d*tg)")"
+
+  PEAK_RSS_KB="$(grep -oP 'Maximum resident set size \(kbytes\):\s*\K[0-9]+' "$TIMELOG" | head -n 1 || echo 0)"
+  KV_EST_KB="$(CFG="$CFG" CTX_LEN="$ctx_len" KV_BYTES_PER_ELEM="${KV_BYTES_PER_ELEM:-2}" kv_est_kb || echo 0)"
+
+  echo "TTFT=$TTFT_MS ms Prefill=$PREFILL_MS ms Decode=$DECODE_MS ms Decode/tok=$DECODE_PER_TOK ms peakRSS=$PEAK_RSS_KB KB KV_est=$KV_EST_KB KB"
+
+  echo "$TS,$GIT,$ARCH,tiny_llama,$mode,$ctx_len,$pp,$tg,$THREADS,$TTFT_MS,$PREFILL_MS,$DECODE_MS,$DECODE_PER_TOK,$PEAK_RSS_KB,$KV_EST_KB" >> "$OUTCSV"
+}
+
+for CTX in $CTX_LENS; do
+  run_one "decode_heavy" "$CTX" "$TG_DH"
+  run_one "prefill_ttft" "$CTX" "$TG_TTFT"
+done
+
+echo
+echo "DONE -> $OUTCSV"
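Note that `grep -oP` (PCRE) and `/usr/bin/time -v` are GNU-specific, so the sweep script as written assumes a Linux host. Where they are unavailable (e.g. macOS), the same log scraping can be done in Python; a minimal sketch, assuming the label format printed by main.cpp above:

```python
import re, sys

# Pull "TTFT :", "Prefill Latency :" and "Decode Latency :" values out of a
# benchmark log; equivalent to the grep -oP calls in sweep_context_v2.sh.
def scrape(log_text, label):
    m = re.search(rf"{label}\s*:\s*([0-9.]+)", log_text)
    return float(m.group(1)) if m else 0.0

log = open(sys.argv[1]).read()  # e.g. bench_context/run_decode_heavy_ctx2048.all
print(scrape(log, "TTFT"), scrape(log, "Prefill Latency"), scrape(log, "Decode Latency"))
```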