Skip to content

Commit 852a20b

Browse files
Support KVCache rewind for stateful LLMs via SetEpDynamicOptions
1 parent f9d2381 commit 852a20b

File tree

8 files changed

+82
-1
lines changed

8 files changed

+82
-1
lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,5 +550,11 @@ void BackendManager::ShutdownBackendManager() {
550550
concrete_backend_.reset();
551551
}
552552

553+
void BackendManager::RewindKVCache(size_t index) {
554+
if (concrete_backend_) {
555+
concrete_backend_->RewindKVCache(index);
556+
}
557+
}
558+
553559
} // namespace openvino_ep
554560
} // namespace onnxruntime

onnxruntime/core/providers/openvino/backend_manager.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class BackendManager {
3030
SessionContext& GetSessionContext();
3131
Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph);
3232
ov::CompiledModel& GetOVCompiledModel();
33+
void RewindKVCache(size_t index);
3334

3435
private:
3536
std::unique_ptr<ONNX_NAMESPACE::ModelProto> GetModelProtoFromFusedNode(

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,13 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) {
358358
device_config.emplace(ov::inference_num_threads(session_context_.num_of_threads));
359359
}
360360

361+
void BasicBackend::RewindKVCache(size_t index) {
362+
OVInferRequestPtr infer_request;
363+
infer_request = inferRequestsQueue_->getIdleRequest();
364+
infer_request->RewindKVCache(index);
365+
inferRequestsQueue_->putIdleRequest(std::move(infer_request));
366+
}
367+
361368
// Starts an asynchronous inference request for data in slice indexed by batch_slice_idx on
362369
// an Infer Request indexed by infer_req_idx
363370
void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) {

onnxruntime/core/providers/openvino/backends/basic_backend.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class BasicBackend : public IBackend {
4141
ov::CompiledModel& GetOVCompiledModel() override {
4242
return exe_network_.Get();
4343
}
44+
void RewindKVCache(size_t index) override;
4445

4546
private:
4647
void PopulateCompiledDirectory(std::string, std::string&, std::string&, bool&);

onnxruntime/core/providers/openvino/ibackend.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class IBackend {
1717
virtual void Infer(OrtKernelContext* context) = 0;
1818
virtual ov::CompiledModel& GetOVCompiledModel() = 0;
1919
virtual ~IBackend() = default;
20+
virtual void RewindKVCache(size_t index) {};
2021
};
2122
using ptr_stream_t = std::unique_ptr<std::istream>;
2223
class BackendFactory {

onnxruntime/core/providers/openvino/openvino_execution_provider.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,24 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span<const ch
294294
ov_compiled_model.set_property(ov::workload_type(workload_type));
295295
}
296296
}
297+
} else if (key == "kvcache_rewind") {
298+
// convert kvcache_rewind value to int64_t
299+
int64_t index;
300+
try {
301+
index = std::stoll(value);
302+
} catch (const std::exception& e) {
303+
LOGS_DEFAULT(WARNING) << "Could not convert kvcache_rewind value string to index. Exception: " + std::string(e.what());
304+
return Status::OK();
305+
}
306+
307+
// Trigger KVCache rewind for backed
308+
for (auto& backend : backend_managers_) {
309+
if (index >= 0) {
310+
backend.RewindKVCache(static_cast<size_t>(index));
311+
} else {
312+
LOGS_DEFAULT(WARNING) << "kvcache_rewind index is < 0: " << index;
313+
}
314+
}
297315
} else {
298316
// Handle unknown options
299317
LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << key << "/" << value;

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
105105

106106
if (hw_target.find("NPU") != std::string::npos) {
107107
KVDesc kv_desc;
108-
kv_desc.max_prompt_len = PopIntAndCast(config, "MAX_PROMPT_LEN").value_or(3072u);
108+
kv_desc.max_prompt_len = PopIntAndCast(config, "MAX_PROMPT_LEN").value_or(1024u);
109109
kv_desc.min_response_len = PopIntAndCast(config, "MIN_RESPONSE_LEN").value_or(128u);
110110

111111
if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
@@ -488,5 +488,50 @@ void StatefulOVInferRequest::Infer() {
488488
OVInferRequest::Infer();
489489
}
490490

491+
void StatefulOVInferRequest::RewindKVCache(size_t index) {
492+
if (device == "NPU") {
493+
std::cout << "RewindKVCache on NPU: Trimming cached input_ids / position_ids to length "
494+
<< index << std::endl;
495+
if (cached_input_ids.size() > index) {
496+
cached_input_ids.resize(index);
497+
}
498+
499+
if (cached_position_ids.size() > index) {
500+
cached_position_ids.resize(index);
501+
}
502+
} else {
503+
std::cout << "OVInferRequest::RewindKVCache: Trimming internal states to length = "
504+
<< index << std::endl;
505+
if (index == 0) {
506+
// in this case, since we're trimming *all* of the KVCache, just reset the state.
507+
ovInfReq.reset_state();
508+
} else {
509+
// retrieve kvcache states, and trim...
510+
// Most of this code was grabbed from here:
511+
// https://github.com/openvinotoolkit/openvino.genai/blob/releases/2025/1/src/cpp/src/utils.cpp#L329
512+
auto states = ovInfReq.query_state();
513+
for (auto& state : states) {
514+
ov::Tensor old_tensor = state.get_state();
515+
// [BATCH_SIZE, num_kv_heads, seq_len, head_size]
516+
auto shape = old_tensor.get_shape();
517+
518+
if (shape[2] > index) {
519+
shape[2] = index;
520+
521+
ov::Coordinate new_shape_begin{0, 0, 0, 0};
522+
ov::Coordinate new_shape_end{shape};
523+
524+
auto trimmed_tensor = ov::Tensor(old_tensor, new_shape_begin, new_shape_end);
525+
526+
ov::Tensor new_tensor(old_tensor.get_element_type(), shape);
527+
trimmed_tensor.copy_to(new_tensor);
528+
529+
state.set_state(new_tensor);
530+
}
531+
}
532+
}
533+
}
534+
}
535+
491536
} // namespace openvino_ep
492537
} // namespace onnxruntime

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ class OVInferRequest {
133133
ov::InferRequest& GetNewObj() {
134134
return ovInfReq;
135135
}
136+
virtual void RewindKVCache(size_t index) {};
136137
};
137138

138139
class StatefulOVInferRequest : public OVInferRequest {
@@ -141,6 +142,7 @@ class StatefulOVInferRequest : public OVInferRequest {
141142

142143
void StartAsync() override;
143144
void Infer() override;
145+
void RewindKVCache(size_t index) override;
144146

145147
private:
146148
void _pre_infer();

0 commit comments

Comments
 (0)