
Commit 52f1476

Merge pull request #641 from RyanMetcalfeInt8/ryan/stateful_llm_chat_mode_support
Ryan/stateful llm chat mode support
2 parents c2791bc + 852a20b commit 52f1476

File tree

8 files changed: +198 -32 lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 6 additions & 0 deletions
@@ -550,5 +550,11 @@ void BackendManager::ShutdownBackendManager() {
   concrete_backend_.reset();
 }

+void BackendManager::RewindKVCache(size_t index) {
+  if (concrete_backend_) {
+    concrete_backend_->RewindKVCache(index);
+  }
+}
+
 }  // namespace openvino_ep
 }  // namespace onnxruntime

onnxruntime/core/providers/openvino/backend_manager.h

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ class BackendManager {
   SessionContext& GetSessionContext();
   Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph);
   ov::CompiledModel& GetOVCompiledModel();
+  void RewindKVCache(size_t index);

  private:
   std::unique_ptr<ONNX_NAMESPACE::ModelProto> GetModelProtoFromFusedNode(

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 7 additions & 0 deletions
@@ -358,6 +358,13 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) {
   device_config.emplace(ov::inference_num_threads(session_context_.num_of_threads));
 }

+void BasicBackend::RewindKVCache(size_t index) {
+  OVInferRequestPtr infer_request;
+  infer_request = inferRequestsQueue_->getIdleRequest();
+  infer_request->RewindKVCache(index);
+  inferRequestsQueue_->putIdleRequest(std::move(infer_request));
+}
+
 // Starts an asynchronous inference request for data in slice indexed by batch_slice_idx on
 // an Infer Request indexed by infer_req_idx
 void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) {

onnxruntime/core/providers/openvino/backends/basic_backend.h

Lines changed: 2 additions & 1 deletion
@@ -41,6 +41,7 @@ class BasicBackend : public IBackend {
   ov::CompiledModel& GetOVCompiledModel() override {
     return exe_network_.Get();
   }
+  void RewindKVCache(size_t index) override;

  private:
   void PopulateCompiledDirectory(std::string, std::string&, std::string&, bool&);
@@ -78,7 +79,7 @@ class InferRequestsQueue {
   InferRequestsQueue(OVExeNetwork& net, size_t nireq, std::function<void(OVInferRequestPtr)> initializer) {
     OVInferRequestPtr infer_request;
     for (size_t id = 0; id < nireq; id++) {
-      infer_request = std::make_shared<OVInferRequest>(net.CreateInferRequest());
+      infer_request = net.CreateInferRequest();
       initializer(infer_request);
       infer_requests_.push_back(infer_request);
     }

onnxruntime/core/providers/openvino/ibackend.h

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ class IBackend {
   virtual void Infer(OrtKernelContext* context) = 0;
   virtual ov::CompiledModel& GetOVCompiledModel() = 0;
   virtual ~IBackend() = default;
+  virtual void RewindKVCache(size_t index) {};
 };
 using ptr_stream_t = std::unique_ptr<std::istream>;
 class BackendFactory {

onnxruntime/core/providers/openvino/openvino_execution_provider.cc

Lines changed: 18 additions & 0 deletions
@@ -294,6 +294,24 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span<const ch
         ov_compiled_model.set_property(ov::workload_type(workload_type));
       }
     }
+  } else if (key == "kvcache_rewind") {
+    // convert kvcache_rewind value to int64_t
+    int64_t index;
+    try {
+      index = std::stoll(value);
+    } catch (const std::exception& e) {
+      LOGS_DEFAULT(WARNING) << "Could not convert kvcache_rewind value string to index. Exception: " + std::string(e.what());
+      return Status::OK();
+    }
+
+    // Trigger KVCache rewind for each backend
+    for (auto& backend : backend_managers_) {
+      if (index >= 0) {
+        backend.RewindKVCache(static_cast<size_t>(index));
+      } else {
+        LOGS_DEFAULT(WARNING) << "kvcache_rewind index is < 0: " << index;
+      }
+    }
   } else {
     // Handle unknown options
     LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << key << "/" << value;
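
For context, the new "kvcache_rewind" option is consumed through the EP dynamic-options path shown above. Below is a minimal sketch of how a host application might drive it, assuming an ONNX Runtime build whose C++ API exposes Session::SetEpDynamicOptions (the wrapper over OrtApi::SetEpDynamicOptions); session setup is omitted and the function name RewindTo is illustrative only:

#include <string>
#include <onnxruntime_cxx_api.h>

// Ask the OpenVINO EP to trim its KV cache back to `token_count` tokens.
// The key "kvcache_rewind" and its stoll-parsed value mirror the handling added above.
void RewindTo(Ort::Session& session, int64_t token_count) {
  const std::string value = std::to_string(token_count);
  const char* keys[] = {"kvcache_rewind"};
  const char* values[] = {value.c_str()};
  session.SetEpDynamicOptions(keys, values, 1);
}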

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 136 additions & 26 deletions
@@ -125,7 +125,7 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
   compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config);
   std::cout << "Stateful OV Model Compilation Complete" << std::endl;

-  OVExeNetwork exe(compiled_model);
+  OVExeNetwork exe(compiled_model, hw_target, true);
   return exe;
 }

@@ -134,19 +134,18 @@ OVExeNetwork OVCore::CompileModel(std::shared_ptr<const OVNetwork>& ie_cnn_netwo
                                   ov::AnyMap& device_config,
                                   bool enable_causallm,
                                   const std::string& name) {
-  ov::CompiledModel obj;
+  OVExeNetwork exe;
   try {
     if (enable_causallm) {
       auto mutable_model = ie_cnn_network->clone();
-      auto compiled_model = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config);
-      obj = compiled_model.Get();
+      exe = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config);
     } else {
-      obj = core.compile_model(ie_cnn_network, hw_target, device_config);
+      auto obj = core.compile_model(ie_cnn_network, hw_target, device_config);
+      exe = OVExeNetwork(obj, hw_target);
     }
 #ifndef NDEBUG
     printDebugInfo(obj);
 #endif
-    OVExeNetwork exe(obj);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -165,7 +164,7 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model,
 #ifndef NDEBUG
     printDebugInfo(obj);
 #endif
-    OVExeNetwork exe(obj);
+    OVExeNetwork exe(obj, hw_target);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -180,7 +179,7 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
                                  bool enable_causallm,
                                  std::string name) {
   try {
-    ov::CompiledModel obj;
+    OVExeNetwork exe;

     // Check if it's XML
     std::streampos originalPos = model_stream.tellg();
@@ -194,7 +193,8 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
     model_stream.seekg(originalPos);

     if (header != "<?xml") {
-      obj = core.import_model(model_stream, hw_target, device_config);
+      auto obj = core.import_model(model_stream, hw_target, device_config);
+      exe = OVExeNetwork(obj, hw_target);
     } else {
       // Get path to bin file
       std::string bin_file;
@@ -232,17 +232,16 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
       std::shared_ptr<ov::Model> model = core.read_model(xml_content, weights_tensor);

       if (enable_causallm) {
-        auto compiled_model = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);
-        obj = compiled_model.Get();
+        exe = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);
       } else {
-        obj = core.compile_model(model, hw_target, device_config);
+        auto obj = core.compile_model(model, hw_target, device_config);
+        exe = OVExeNetwork(obj, hw_target);
       }
     }

 #ifndef NDEBUG
     printDebugInfo(obj);
 #endif
-    OVExeNetwork exe(obj);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -330,11 +329,16 @@ void OVCore::SetStreams(const std::string& device_type, int num_streams) {
   core.set_property(device_type, {ov::num_streams(num_streams)});
 }

-OVInferRequest OVExeNetwork::CreateInferRequest() {
+std::shared_ptr<OVInferRequest> OVExeNetwork::CreateInferRequest() {
   try {
     auto infReq = obj.create_infer_request();
-    OVInferRequest inf_obj(std::move(infReq));
-    return inf_obj;
+    std::shared_ptr<OVInferRequest> ovInfReq;
+    if (_stateful_llm) {
+      ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), _device);
+    } else {
+      ovInfReq = std::make_shared<OVInferRequest>(std::move(infReq));
+    }
+    return ovInfReq;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + "Exception while creating InferRequest object: " + e.what());
   } catch (...) {
@@ -368,16 +372,6 @@ std::string OVInferRequest::GetInputTensorName(uint32_t index) {
 void OVInferRequest::SetTensor(const std::string& name, OVTensorPtr& blob) {
   try {
     ovInfReq.set_tensor(name, *(blob.get()));
-
-    if (name == "input_ids") {
-      // Since we can't seem to set at ORT GenAI layer right now, we just set it here
-      // as a workaround.
-      // TODO: Fix this.
-      ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {1});
-      std::fill_n(beam_idx.data<int32_t>(), 1, 0);
-      ovInfReq.set_tensor("beam_idx", beam_idx);
-    }
-
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Cannot set Remote Blob for output: " + name + e.what());
   } catch (...) {
@@ -423,5 +417,121 @@ void OVInferRequest::QueryStatus() {
   std::cout << "ovInfReq.query_state()"
             << " ";
 }
+
+void StatefulOVInferRequest::_pre_infer() {
+  // Since we can't seem to set at ORT GenAI layer right now, we just set it here
+  // as a workaround.
+  // TODO: Fix this.
+  ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {1});
+  std::fill_n(beam_idx.data<int32_t>(), 1, 0);
+  ovInfReq.set_tensor("beam_idx", beam_idx);
+
+  // For NPU, we need to cache input_ids and position_ids for
+  // chat-mode support.
+  if (device.find("NPU") != std::string::npos) {
+    auto input_ids_tensor = ovInfReq.get_tensor("input_ids");
+
+    // add input_ids to our cache
+    {
+      auto* pData = input_ids_tensor.data<int64_t>();
+      for (size_t i = 0; i < input_ids_tensor.get_size(); i++) {
+        cached_input_ids.push_back(pData[i]);
+      }
+    }
+
+    // add position_ids to our cache
+    {
+      auto position_ids = ovInfReq.get_tensor("position_ids");
+      auto* pData = position_ids.data<int64_t>();
+      for (size_t i = 0; i < position_ids.get_size(); i++) {
+        cached_position_ids.push_back(pData[i]);
+      }
+    }

+    // if we're about to run prefill model
+    if (input_ids_tensor.get_size() > 1) {
+      // if the input_ids size doesn't equal cached size of the input_ids
+      // then it means that we're running 2nd (or later) prompt.
+      if (input_ids_tensor.get_shape()[1] != cached_input_ids.size()) {
+        // set a new input_ids tensor with the content of our cached input_ids
+        {
+          auto new_shape = input_ids_tensor.get_shape();
+          new_shape[1] = cached_input_ids.size();
+          auto new_input_ids = ov::Tensor(input_ids_tensor.get_element_type(), new_shape);
+          auto* pNewInputIds = new_input_ids.data<int64_t>();
+          std::memcpy(pNewInputIds, cached_input_ids.data(), cached_input_ids.size() * sizeof(int64_t));
+          ovInfReq.set_tensor("input_ids", new_input_ids);
+        }
+
+        // set a new position_ids tensor with the content of our cached position_ids
+        {
+          auto position_ids_tensor = ovInfReq.get_tensor("position_ids");
+          auto new_shape = position_ids_tensor.get_shape();
+          new_shape[1] = cached_position_ids.size();
+          auto new_position_ids = ov::Tensor(position_ids_tensor.get_element_type(), new_shape);
+          auto* pNewPositionIds = new_position_ids.data<int64_t>();
+          std::memcpy(pNewPositionIds, cached_position_ids.data(), cached_position_ids.size() * sizeof(int64_t));
+          ovInfReq.set_tensor("position_ids", new_position_ids);
+        }
+      }
+    }
+  }
+}
+
+void StatefulOVInferRequest::StartAsync() {
+  _pre_infer();
+  OVInferRequest::StartAsync();
+}
+
+void StatefulOVInferRequest::Infer() {
+  _pre_infer();
+  OVInferRequest::Infer();
+}
+
+void StatefulOVInferRequest::RewindKVCache(size_t index) {
+  if (device == "NPU") {
+    std::cout << "RewindKVCache on NPU: Trimming cached input_ids / position_ids to length "
+              << index << std::endl;
+    if (cached_input_ids.size() > index) {
+      cached_input_ids.resize(index);
+    }
+
+    if (cached_position_ids.size() > index) {
+      cached_position_ids.resize(index);
+    }
+  } else {
+    std::cout << "OVInferRequest::RewindKVCache: Trimming internal states to length = "
              << index << std::endl;
+    if (index == 0) {
+      // in this case, since we're trimming *all* of the KVCache, just reset the state.
+      ovInfReq.reset_state();
+    } else {
+      // retrieve kvcache states, and trim...
+      // Most of this code was grabbed from here:
+      // https://github.com/openvinotoolkit/openvino.genai/blob/releases/2025/1/src/cpp/src/utils.cpp#L329
+      auto states = ovInfReq.query_state();
+      for (auto& state : states) {
+        ov::Tensor old_tensor = state.get_state();
+        // [BATCH_SIZE, num_kv_heads, seq_len, head_size]
+        auto shape = old_tensor.get_shape();
+
+        if (shape[2] > index) {
+          shape[2] = index;
+
+          ov::Coordinate new_shape_begin{0, 0, 0, 0};
+          ov::Coordinate new_shape_end{shape};
+
+          auto trimmed_tensor = ov::Tensor(old_tensor, new_shape_begin, new_shape_end);
+
+          ov::Tensor new_tensor(old_tensor.get_element_type(), shape);
+          trimmed_tensor.copy_to(new_tensor);
+
+          state.set_state(new_tensor);
+        }
+      }
+    }
+  }
+}
+
 } // namespace openvino_ep
 } // namespace onnxruntime
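
The CPU/GPU branch of StatefulOVInferRequest::RewindKVCache above trims each variable state along its sequence axis with OpenVINO's ROI tensor constructor. Below is a standalone sketch of that trimming step against a plain ov::InferRequest, assuming states shaped [batch, num_kv_heads, seq_len, head_size] as the comment in the diff notes; TrimKVCache is an illustrative name, not part of the PR:

#include <openvino/openvino.hpp>

void TrimKVCache(ov::InferRequest& request, size_t keep_len) {
  if (keep_len == 0) {
    // Dropping the whole cache: resetting every variable state is enough.
    request.reset_state();
    return;
  }
  for (auto& state : request.query_state()) {
    ov::Tensor old_tensor = state.get_state();
    auto shape = old_tensor.get_shape();  // [batch, num_kv_heads, seq_len, head_size]
    if (shape[2] <= keep_len) continue;   // nothing to trim for this state
    shape[2] = keep_len;
    // View the first keep_len positions along the sequence axis, then copy the
    // view into a freshly owned tensor and write it back as the new state.
    ov::Tensor view(old_tensor, ov::Coordinate{0, 0, 0, 0}, ov::Coordinate{shape});
    ov::Tensor trimmed(old_tensor.get_element_type(), shape);
    view.copy_to(trimmed);
    state.set_state(trimmed);
  }
}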

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 27 additions & 5 deletions
@@ -105,31 +105,53 @@ struct OVCore : WeakSingleton<OVCore> {

 class OVExeNetwork {
   ov::CompiledModel obj;
-
+  std::string _device;
+  bool _stateful_llm;
  public:
-  explicit OVExeNetwork(ov::CompiledModel md) : obj(md) {}
+  explicit OVExeNetwork(ov::CompiledModel md, std::string device, bool stateful_llm = false)
+      : obj(md), _device(device), _stateful_llm(stateful_llm) {}
   OVExeNetwork() : obj(ov::CompiledModel()) {}
   ov::CompiledModel& Get() { return obj; }
-  OVInferRequest CreateInferRequest();
+  std::shared_ptr<OVInferRequest> CreateInferRequest();
 };

 class OVInferRequest {
+ protected:
   ov::InferRequest ovInfReq;

  public:
   uint32_t GetNumInputs();
   OVTensorPtr GetTensor(const std::string& name);
   std::string GetInputTensorName(uint32_t index);
   void SetTensor(const std::string& name, OVTensorPtr& blob);
-  void StartAsync();
-  void Infer();
+  virtual void StartAsync();
+  virtual void Infer();
   void WaitRequest();
   void QueryStatus();
   explicit OVInferRequest(ov::InferRequest obj) : ovInfReq(std::move(obj)) {}
   OVInferRequest() : ovInfReq(ov::InferRequest()) {}
   ov::InferRequest& GetNewObj() {
     return ovInfReq;
   }
+  virtual void RewindKVCache(size_t index) {};
 };
+
+class StatefulOVInferRequest : public OVInferRequest {
+ public:
+  explicit StatefulOVInferRequest(ov::InferRequest obj, std::string d) : OVInferRequest(std::move(obj)), device(d) {}
+
+  void StartAsync() override;
+  void Infer() override;
+  void RewindKVCache(size_t index) override;
+
+ private:
+  void _pre_infer();
+  std::string device;
+
+  // For NPU, we need to cache input_ids & position_ids to support chat-mode.
+  std::vector<int64_t> cached_input_ids;
+  std::vector<int64_t> cached_position_ids;
+};
+
 } // namespace openvino_ep
 } // namespace onnxruntime
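
Taken together with ov_interface.cc, the header changes mean OVExeNetwork now decides at construction time which infer-request type it hands out. A rough usage sketch follows (not code from the PR; model, hw_target, and config are placeholders, and the include path is assumed):

#include <memory>
#include <openvino/openvino.hpp>
#include "core/providers/openvino/ov_interface.h"

using namespace onnxruntime::openvino_ep;

void Sketch(std::shared_ptr<ov::Model> model, const std::string& hw_target, ov::AnyMap config) {
  // stateful_llm = true makes CreateInferRequest() return a StatefulOVInferRequest,
  // so StartAsync/Infer run _pre_infer() first and RewindKVCache dispatches to the override.
  auto compiled = OVCore::Get()->core.compile_model(model, hw_target, config);
  OVExeNetwork exe(compiled, hw_target, /*stateful_llm=*/true);

  std::shared_ptr<OVInferRequest> request = exe.CreateInferRequest();
  request->Infer();             // stateful path: beam_idx + cached-ids handling, then base Infer()
  request->RewindKVCache(128);  // trims cached ids on NPU, or variable states on CPU/GPU
}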
