@@ -125,7 +125,7 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
   compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config);
   std::cout << "Stateful OV Model Compilation Complete" << std::endl;
 
-  OVExeNetwork exe(compiled_model);
+  OVExeNetwork exe(compiled_model, hw_target, true);
   return exe;
 }
 
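For context, the two new constructor arguments indicate that OVExeNetwork now records the target device and whether the model was compiled as a stateful LLM. A minimal sketch of the assumed shape of that wrapper follows; the member names (`obj`, `_device`, `_stateful_llm`) are inferred from the CreateInferRequest hunk further down, and the real declaration lives in the accompanying header, which is not part of this diff.

    // Sketch only; inferred from this diff, not copied from the actual header.
    class OVExeNetwork {
      ov::CompiledModel obj;
      std::string _device;
      bool _stateful_llm = false;

     public:
      OVExeNetwork() = default;
      OVExeNetwork(ov::CompiledModel compiled, std::string device, bool stateful_llm = false)
          : obj(std::move(compiled)), _device(std::move(device)), _stateful_llm(stateful_llm) {}
      ov::CompiledModel& Get() { return obj; }
      std::shared_ptr<OVInferRequest> CreateInferRequest();
    };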
@@ -134,19 +134,18 @@ OVExeNetwork OVCore::CompileModel(std::shared_ptr<const OVNetwork>& ie_cnn_netwo
                                   ov::AnyMap& device_config,
                                   bool enable_causallm,
                                   const std::string& name) {
-  ov::CompiledModel obj;
+  OVExeNetwork exe;
   try {
     if (enable_causallm) {
       auto mutable_model = ie_cnn_network->clone();
-      auto compiled_model = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config);
-      obj = compiled_model.Get();
+      exe = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config);
     } else {
-      obj = core.compile_model(ie_cnn_network, hw_target, device_config);
+      auto obj = core.compile_model(ie_cnn_network, hw_target, device_config);
+      exe = OVExeNetwork(obj, hw_target);
     }
 #ifndef NDEBUG
     printDebugInfo(obj);
 #endif
-    OVExeNetwork exe(obj);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -165,7 +164,7 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model,
 #ifndef NDEBUG
     printDebugInfo(obj);
 #endif
-    OVExeNetwork exe(obj);
+    OVExeNetwork exe(obj, hw_target);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -180,7 +179,7 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
                                  bool enable_causallm,
                                  std::string name) {
   try {
-    ov::CompiledModel obj;
+    OVExeNetwork exe;
 
     // Check if it's XML
     std::streampos originalPos = model_stream.tellg();
@@ -194,7 +193,8 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
     model_stream.seekg(originalPos);
 
     if (header != "<?xml") {
-      obj = core.import_model(model_stream, hw_target, device_config);
+      auto obj = core.import_model(model_stream, hw_target, device_config);
+      exe = OVExeNetwork(obj, hw_target);
     } else {
       // Get path to bin file
       std::string bin_file;
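The `header != "<?xml"` comparison distinguishes a precompiled blob from OpenVINO IR XML. The sniffing step itself happens just above this hunk and is not shown here; it presumably amounts to something like the following sketch (not the PR's exact code; variable names match the hunk's context lines).

    // Hypothetical reconstruction of the header sniff: peek the first bytes, then rewind.
    std::streampos originalPos = model_stream.tellg();
    std::string header(5, '\0');
    model_stream.read(&header[0], static_cast<std::streamsize>(header.size()));
    model_stream.seekg(originalPos);  // rewind so import_model/read_model sees the full stream
    const bool looks_like_xml = (header == "<?xml");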
@@ -232,17 +232,16 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
       std::shared_ptr<ov::Model> model = core.read_model(xml_content, weights_tensor);
 
       if (enable_causallm) {
-        auto compiled_model = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);
-        obj = compiled_model.Get();
+        exe = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);
       } else {
-        obj = core.compile_model(model, hw_target, device_config);
+        auto obj = core.compile_model(model, hw_target, device_config);
+        exe = OVExeNetwork(obj, hw_target);
       }
     }
 
 #ifndef NDEBUG
     printDebugInfo(obj);
 #endif
-    OVExeNetwork exe(obj);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -330,11 +329,16 @@ void OVCore::SetStreams(const std::string& device_type, int num_streams) {
   core.set_property(device_type, {ov::num_streams(num_streams)});
 }
 
-OVInferRequest OVExeNetwork::CreateInferRequest() {
+std::shared_ptr<OVInferRequest> OVExeNetwork::CreateInferRequest() {
   try {
     auto infReq = obj.create_infer_request();
-    OVInferRequest inf_obj(std::move(infReq));
-    return inf_obj;
+    std::shared_ptr<OVInferRequest> ovInfReq;
+    if (_stateful_llm) {
+      ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), _device);
+    } else {
+      ovInfReq = std::make_shared<OVInferRequest>(std::move(infReq));
+    }
+    return ovInfReq;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while creating InferRequest object: " + e.what());
   } catch (...) {
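The factory now returns a `std::shared_ptr<OVInferRequest>` so the stateful subclass can be substituted transparently. A minimal caller-side sketch, assuming `Infer()`/`StartAsync()` are declared virtual in the base class; the caller variables below are illustrative and not part of this diff.

    // Hypothetical caller (e.g. the EP's backend code).
    OVExeNetwork exe = OVCore::Get()->CompileModel(model, hw_target, device_config,
                                                   /*enable_causallm=*/true, "llm_graph");
    std::shared_ptr<OVInferRequest> req = exe.CreateInferRequest();
    req->SetTensor("input_ids", input_ids_blob);
    req->Infer();  // dispatches to StatefulOVInferRequest::Infer(), which injects beam_idx
                   // and (on NPU) replays cached input_ids / position_ids before inference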
@@ -368,16 +372,6 @@ std::string OVInferRequest::GetInputTensorName(uint32_t index) {
 void OVInferRequest::SetTensor(const std::string& name, OVTensorPtr& blob) {
   try {
     ovInfReq.set_tensor(name, *(blob.get()));
-
-    if (name == "input_ids") {
-      // Since we can't seem to set at ORT GenAI layer right now, we just set it here
-      // as a workaround.
-      // TODO: Fix this.
-      ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {1});
-      std::fill_n(beam_idx.data<int32_t>(), 1, 0);
-      ovInfReq.set_tensor("beam_idx", beam_idx);
-    }
-
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Cannot set Remote Blob for output: " + name + e.what());
   } catch (...) {
@@ -423,5 +417,121 @@ void OVInferRequest::QueryStatus() {
   std::cout << "ovInfReq.query_state()"
             << " ";
 }
+
+void StatefulOVInferRequest::_pre_infer() {
+  // Since we can't seem to set at ORT GenAI layer right now, we just set it here
+  // as a workaround.
+  // TODO: Fix this.
+  ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {1});
+  std::fill_n(beam_idx.data<int32_t>(), 1, 0);
+  ovInfReq.set_tensor("beam_idx", beam_idx);
+
+  // For NPU, we need to cache input_ids and position_ids for
+  // chat-mode support.
+  if (device.find("NPU") != std::string::npos) {
+    auto input_ids_tensor = ovInfReq.get_tensor("input_ids");
+
+    // add input_ids to our cache
+    {
+      auto* pData = input_ids_tensor.data<int64_t>();
+      for (size_t i = 0; i < input_ids_tensor.get_size(); i++) {
+        cached_input_ids.push_back(pData[i]);
+      }
+    }
+
+    // add position_ids to our cache
+    {
+      auto position_ids = ovInfReq.get_tensor("position_ids");
+      auto* pData = position_ids.data<int64_t>();
+      for (size_t i = 0; i < position_ids.get_size(); i++) {
+        cached_position_ids.push_back(pData[i]);
+      }
+    }
+
+    // if we're about to run prefill model
+    if (input_ids_tensor.get_size() > 1) {
+      // if the input_ids size doesn't equal cached size of the input_ids
+      // then it means that we're running 2nd (or later) prompt.
+      if (input_ids_tensor.get_shape()[1] != cached_input_ids.size()) {
+        // set a new input_ids tensor with the content of our cached input_ids
+        {
+          auto new_shape = input_ids_tensor.get_shape();
+          new_shape[1] = cached_input_ids.size();
+          auto new_input_ids = ov::Tensor(input_ids_tensor.get_element_type(), new_shape);
+          auto* pNewInputIds = new_input_ids.data<int64_t>();
+          std::memcpy(pNewInputIds, cached_input_ids.data(), cached_input_ids.size() * sizeof(int64_t));
+          ovInfReq.set_tensor("input_ids", new_input_ids);
+        }
+
+        // set a new position_ids tensor with the content of our cached position_ids
+        {
+          auto position_ids_tensor = ovInfReq.get_tensor("position_ids");
+          auto new_shape = position_ids_tensor.get_shape();
+          new_shape[1] = cached_position_ids.size();
+          auto new_position_ids = ov::Tensor(position_ids_tensor.get_element_type(), new_shape);
+          auto* pNewPositionIds = new_position_ids.data<int64_t>();
+          std::memcpy(pNewPositionIds, cached_position_ids.data(), cached_position_ids.size() * sizeof(int64_t));
+          ovInfReq.set_tensor("position_ids", new_position_ids);
+        }
+      }
+    }
+  }
+}
+
+void StatefulOVInferRequest::StartAsync() {
+  _pre_infer();
+  OVInferRequest::StartAsync();
+}
+
+void StatefulOVInferRequest::Infer() {
+  _pre_infer();
+  OVInferRequest::Infer();
+}
+
+void StatefulOVInferRequest::RewindKVCache(size_t index) {
+  if (device == "NPU") {
+    std::cout << "RewindKVCache on NPU: Trimming cached input_ids / position_ids to length "
+              << index << std::endl;
+    if (cached_input_ids.size() > index) {
+      cached_input_ids.resize(index);
+    }
+
+    if (cached_position_ids.size() > index) {
+      cached_position_ids.resize(index);
+    }
+  } else {
+    std::cout << "OVInferRequest::RewindKVCache: Trimming internal states to length = "
+              << index << std::endl;
+    if (index == 0) {
+      // in this case, since we're trimming *all* of the KVCache, just reset the state.
+      ovInfReq.reset_state();
+    } else {
+      // retrieve kvcache states, and trim...
+      // Most of this code was grabbed from here:
+      // https://github.com/openvinotoolkit/openvino.genai/blob/releases/2025/1/src/cpp/src/utils.cpp#L329
+      auto states = ovInfReq.query_state();
+      for (auto& state : states) {
+        ov::Tensor old_tensor = state.get_state();
+        // [BATCH_SIZE, num_kv_heads, seq_len, head_size]
+        auto shape = old_tensor.get_shape();
+
+        if (shape[2] > index) {
+          shape[2] = index;
+
+          ov::Coordinate new_shape_begin{0, 0, 0, 0};
+          ov::Coordinate new_shape_end{shape};
+
+          auto trimmed_tensor = ov::Tensor(old_tensor, new_shape_begin, new_shape_end);
+
+          ov::Tensor new_tensor(old_tensor.get_element_type(), shape);
+          trimmed_tensor.copy_to(new_tensor);
+
+          state.set_state(new_tensor);
+        }
+      }
+    }
+  }
+}
+
 } // namespace openvino_ep
 } // namespace onnxruntime
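To make the non-NPU rewind path easier to follow in isolation: each variable state's KV tensor is shrunk along the sequence axis by taking a zero-copy ROI over the kept prefix and then copying it into a fresh tensor before calling set_state, mirroring the openvino.genai helper linked in the code. A self-contained sketch of the same technique, assuming the `[batch, num_kv_heads, seq_len, head_size]` layout noted in the diff (the function name is illustrative):

    #include <openvino/openvino.hpp>

    // Trim one KV-cache variable state so that only the first `keep_len` positions survive.
    void TrimState(ov::VariableState& state, size_t keep_len) {
      ov::Tensor old_tensor = state.get_state();
      auto shape = old_tensor.get_shape();     // [batch, num_kv_heads, seq_len, head_size]
      if (shape[2] <= keep_len) return;        // nothing to trim
      shape[2] = keep_len;
      ov::Coordinate begin{0, 0, 0, 0};
      ov::Coordinate end{shape};
      ov::Tensor roi(old_tensor, begin, end);  // view over the kept prefix, no copy yet
      ov::Tensor contiguous(old_tensor.get_element_type(), shape);
      roi.copy_to(contiguous);                 // densely pack the kept slice
      state.set_state(contiguous);
    }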