@@ -178,6 +178,7 @@ struct cmd_params {
178
178
std::vector<std::vector<float >> tensor_split;
179
179
std::vector<bool > use_mmap;
180
180
std::vector<bool > embeddings;
181
+ ggml_numa_strategy numa;
181
182
int reps;
182
183
bool verbose;
183
184
output_formats output_format;
@@ -200,6 +201,7 @@ static const cmd_params cmd_params_defaults = {
200
201
/* tensor_split */ {std::vector<float >(llama_max_devices (), 0 .0f )},
201
202
/* use_mmap */ {true },
202
203
/* embeddings */ {false },
204
+ /* numa */ GGML_NUMA_STRATEGY_DISABLED,
203
205
/* reps */ 5 ,
204
206
/* verbose */ false ,
205
207
/* output_format */ MARKDOWN
@@ -224,6 +226,7 @@ static void print_usage(int /* argc */, char ** argv) {
224
226
printf (" -nkvo, --no-kv-offload <0|1> (default: %s)\n " , join (cmd_params_defaults.no_kv_offload , " ," ).c_str ());
225
227
printf (" -fa, --flash-attn <0|1> (default: %s)\n " , join (cmd_params_defaults.flash_attn , " ," ).c_str ());
226
228
printf (" -mmp, --mmap <0|1> (default: %s)\n " , join (cmd_params_defaults.use_mmap , " ," ).c_str ());
229
+ printf (" --numa <distribute|isolate|numactl> (default: disabled)\n " );
227
230
printf (" -embd, --embeddings <0|1> (default: %s)\n " , join (cmd_params_defaults.embeddings , " ," ).c_str ());
228
231
printf (" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n " );
229
232
printf (" -r, --repetitions <n> (default: %d)\n " , cmd_params_defaults.reps );
@@ -396,6 +399,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
396
399
}
397
400
auto p = split<bool >(argv[i], split_delim);
398
401
params.no_kv_offload .insert (params.no_kv_offload .end (), p.begin (), p.end ());
402
+ } else if (arg == " --numa" ) {
403
+ if (++i >= argc) {
404
+ invalid_param = true ;
405
+ break ;
406
+ } else {
407
+ std::string value (argv[i]);
408
+ /* */ if (value == " distribute" || value == " " ) { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; }
409
+ else if (value == " isolate" ) { params.numa = GGML_NUMA_STRATEGY_ISOLATE; }
410
+ else if (value == " numactl" ) { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
411
+ else { invalid_param = true ; break ; }
412
+ }
399
413
} else if (arg == " -fa" || arg == " --flash-attn" ) {
400
414
if (++i >= argc) {
401
415
invalid_param = true ;
@@ -1215,6 +1229,7 @@ int main(int argc, char ** argv) {
1215
1229
llama_log_set (llama_null_log_callback, NULL );
1216
1230
}
1217
1231
llama_backend_init ();
1232
+ llama_numa_init (params.numa );
1218
1233
1219
1234
// initialize printer
1220
1235
std::unique_ptr<printer> p;
0 commit comments