Skip to content

Commit 628b299

Browse files
authored
Adding support for the --numa argument for llama-bench. (ggml-org#7080)
1 parent 8f8acc8 commit 628b299

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

examples/llama-bench/llama-bench.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ struct cmd_params {
178178
std::vector<std::vector<float>> tensor_split;
179179
std::vector<bool> use_mmap;
180180
std::vector<bool> embeddings;
181+
ggml_numa_strategy numa;
181182
int reps;
182183
bool verbose;
183184
output_formats output_format;
@@ -200,6 +201,7 @@ static const cmd_params cmd_params_defaults = {
200201
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
201202
/* use_mmap */ {true},
202203
/* embeddings */ {false},
204+
/* numa */ GGML_NUMA_STRATEGY_DISABLED,
203205
/* reps */ 5,
204206
/* verbose */ false,
205207
/* output_format */ MARKDOWN
@@ -224,6 +226,7 @@ static void print_usage(int /* argc */, char ** argv) {
224226
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
225227
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
226228
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
229+
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
227230
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
228231
printf(" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n");
229232
printf(" -r, --repetitions <n> (default: %d)\n", cmd_params_defaults.reps);
@@ -396,6 +399,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
396399
}
397400
auto p = split<bool>(argv[i], split_delim);
398401
params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
402+
} else if (arg == "--numa") {
403+
if (++i >= argc) {
404+
invalid_param = true;
405+
break;
406+
} else {
407+
std::string value(argv[i]);
408+
/**/ if (value == "distribute" || value == "" ) { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; }
409+
else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; }
410+
else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
411+
else { invalid_param = true; break; }
412+
}
399413
} else if (arg == "-fa" || arg == "--flash-attn") {
400414
if (++i >= argc) {
401415
invalid_param = true;
@@ -1215,6 +1229,7 @@ int main(int argc, char ** argv) {
12151229
llama_log_set(llama_null_log_callback, NULL);
12161230
}
12171231
llama_backend_init();
1232+
llama_numa_init(params.numa);
12181233

12191234
// initialize printer
12201235
std::unique_ptr<printer> p;

0 commit comments

Comments (0)