Commit f9d2381

Support chat-mode for NPU
1 parent c2791bc commit f9d2381

File tree

3 files changed: +118, -33 lines

onnxruntime/core/providers/openvino/backends/basic_backend.h

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ class InferRequestsQueue {
   InferRequestsQueue(OVExeNetwork& net, size_t nireq, std::function<void(OVInferRequestPtr)> initializer) {
     OVInferRequestPtr infer_request;
     for (size_t id = 0; id < nireq; id++) {
-      infer_request = std::make_shared<OVInferRequest>(net.CreateInferRequest());
+      infer_request = net.CreateInferRequest();
       initializer(infer_request);
       infer_requests_.push_back(infer_request);
     }
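
This line changes because CreateInferRequest() (see ov_interface.cc/.h below) now returns a std::shared_ptr<OVInferRequest> that may actually point at a StatefulOVInferRequest. Below is a minimal sketch, using hypothetical Request/StatefulRequest stand-ins rather than the EP classes, of why the queue keeps that pointer as-is: re-wrapping it in std::make_shared<OVInferRequest>(...) would copy only the base part and drop the stateful override.

// Minimal standalone sketch (hypothetical types, not the EP classes).
#include <iostream>
#include <memory>

struct Request {
  virtual ~Request() = default;
  virtual void Infer() { std::cout << "plain infer\n"; }
};

struct StatefulRequest : Request {
  void Infer() override { std::cout << "stateful infer (pre-infer hook runs)\n"; }
};

// Factory shaped like OVExeNetwork::CreateInferRequest() after this commit.
std::shared_ptr<Request> CreateRequest(bool stateful) {
  if (stateful) return std::make_shared<StatefulRequest>();
  return std::make_shared<Request>();
}

int main() {
  auto req = CreateRequest(/*stateful=*/true);
  req->Infer();  // virtual dispatch: prints the stateful message

  // The old pattern re-wrapped the result in a fresh base object,
  // slicing off the derived behavior:
  auto sliced = std::make_shared<Request>(*req);
  sliced->Infer();  // prints "plain infer"
  return 0;
}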

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 92 additions & 27 deletions
@@ -105,7 +105,7 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
 
   if (hw_target.find("NPU") != std::string::npos) {
     KVDesc kv_desc;
-    kv_desc.max_prompt_len = PopIntAndCast(config, "MAX_PROMPT_LEN").value_or(1024u);
+    kv_desc.max_prompt_len = PopIntAndCast(config, "MAX_PROMPT_LEN").value_or(3072u);
     kv_desc.min_response_len = PopIntAndCast(config, "MIN_RESPONSE_LEN").value_or(128u);
 
     if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
@@ -125,7 +125,7 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
   compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config);
   std::cout << "Stateful OV Model Compilation Complete" << std::endl;
 
-  OVExeNetwork exe(compiled_model);
+  OVExeNetwork exe(compiled_model, hw_target, true);
   return exe;
 }
 
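The NPU branch above pops MAX_PROMPT_LEN and MIN_RESPONSE_LEN out of the device config, falling back to defaults (now 3072 and 128) when the user supplies nothing. A minimal sketch of that pop-with-default pattern, using a hypothetical PopUint helper and a plain std::map in place of the EP's PopIntAndCast and ov::AnyMap:

// Standalone sketch (hypothetical PopUint; the EP's PopIntAndCast / ov::AnyMap
// are not reproduced here) of reading optional keys with value_or defaults.
#include <cstdint>
#include <iostream>
#include <map>
#include <optional>
#include <string>

// Remove `key` from the config map and return its value, if present.
std::optional<uint32_t> PopUint(std::map<std::string, uint32_t>& config, const std::string& key) {
  auto it = config.find(key);
  if (it == config.end()) return std::nullopt;
  uint32_t value = it->second;
  config.erase(it);
  return value;
}

int main() {
  std::map<std::string, uint32_t> config = {{"MAX_PROMPT_LEN", 2048}};

  // An explicitly-set key wins; a missing key falls back to the default.
  uint32_t max_prompt_len = PopUint(config, "MAX_PROMPT_LEN").value_or(3072u);
  uint32_t min_response_len = PopUint(config, "MIN_RESPONSE_LEN").value_or(128u);

  std::cout << max_prompt_len << " " << min_response_len << "\n";  // prints "2048 128"
  // The popped keys are no longer in `config`.
  return 0;
}

Popping (rather than just reading) the keys presumably keeps these EP-specific entries out of the property map that is later handed to compile_model.
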
@@ -134,19 +134,18 @@ OVExeNetwork OVCore::CompileModel(std::shared_ptr<const OVNetwork>& ie_cnn_netwo
                                   ov::AnyMap& device_config,
                                   bool enable_causallm,
                                   const std::string& name) {
-  ov::CompiledModel obj;
+  OVExeNetwork exe;
   try {
     if (enable_causallm) {
       auto mutable_model = ie_cnn_network->clone();
-      auto compiled_model = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config);
-      obj = compiled_model.Get();
+      exe = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config);
     } else {
-      obj = core.compile_model(ie_cnn_network, hw_target, device_config);
+      auto obj = core.compile_model(ie_cnn_network, hw_target, device_config);
+      exe = OVExeNetwork(obj, hw_target);
     }
 #ifndef NDEBUG
     printDebugInfo(obj);
 #endif
-    OVExeNetwork exe(obj);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -165,7 +164,7 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model,
 #ifndef NDEBUG
     printDebugInfo(obj);
 #endif
-    OVExeNetwork exe(obj);
+    OVExeNetwork exe(obj, hw_target);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -180,7 +179,7 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
                                  bool enable_causallm,
                                  std::string name) {
   try {
-    ov::CompiledModel obj;
+    OVExeNetwork exe;
 
     // Check if it's XML
     std::streampos originalPos = model_stream.tellg();
@@ -194,7 +193,8 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
     model_stream.seekg(originalPos);
 
     if (header != "<?xml") {
-      obj = core.import_model(model_stream, hw_target, device_config);
+      auto obj = core.import_model(model_stream, hw_target, device_config);
+      exe = OVExeNetwork(obj, hw_target);
     } else {
       // Get path to bin file
       std::string bin_file;
@@ -232,17 +232,16 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
       std::shared_ptr<ov::Model> model = core.read_model(xml_content, weights_tensor);
 
       if (enable_causallm) {
-        auto compiled_model = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);
-        obj = compiled_model.Get();
+        exe = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);
       } else {
-        obj = core.compile_model(model, hw_target, device_config);
+        auto obj = core.compile_model(model, hw_target, device_config);
+        exe = OVExeNetwork(obj, hw_target);
       }
     }
 
 #ifndef NDEBUG
     printDebugInfo(obj);
 #endif
-    OVExeNetwork exe(obj);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -330,11 +329,16 @@ void OVCore::SetStreams(const std::string& device_type, int num_streams) {
   core.set_property(device_type, {ov::num_streams(num_streams)});
 }
 
-OVInferRequest OVExeNetwork::CreateInferRequest() {
+std::shared_ptr<OVInferRequest> OVExeNetwork::CreateInferRequest() {
   try {
     auto infReq = obj.create_infer_request();
-    OVInferRequest inf_obj(std::move(infReq));
-    return inf_obj;
+    std::shared_ptr<OVInferRequest> ovInfReq;
+    if (_stateful_llm) {
+      ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), _device);
+    } else {
+      ovInfReq = std::make_shared<OVInferRequest>(std::move(infReq));
+    }
+    return ovInfReq;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + "Exception while creating InferRequest object: " + e.what());
   } catch (...) {
@@ -368,16 +372,6 @@ std::string OVInferRequest::GetInputTensorName(uint32_t index) {
 void OVInferRequest::SetTensor(const std::string& name, OVTensorPtr& blob) {
   try {
     ovInfReq.set_tensor(name, *(blob.get()));
-
-    if (name == "input_ids") {
-      // Since we can't seem to set at ORT GenAI layer right now, we just set it here
-      // as a workaround.
-      // TODO: Fix this.
-      ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {1});
-      std::fill_n(beam_idx.data<int32_t>(), 1, 0);
-      ovInfReq.set_tensor("beam_idx", beam_idx);
-    }
-
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Cannot set Remote Blob for output: " + name + e.what());
   } catch (...) {
@@ -423,5 +417,76 @@ void OVInferRequest::QueryStatus() {
   std::cout << "ovInfReq.query_state()"
             << " ";
 }
+
+void StatefulOVInferRequest::_pre_infer() {
+  // Since we can't seem to set at ORT GenAI layer right now, we just set it here
+  // as a workaround.
+  // TODO: Fix this.
+  ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {1});
+  std::fill_n(beam_idx.data<int32_t>(), 1, 0);
+  ovInfReq.set_tensor("beam_idx", beam_idx);
+
+  // For NPU, we need to cache input_ids and position_ids for
+  // chat-mode support.
+  if (device.find("NPU") != std::string::npos) {
+    auto input_ids_tensor = ovInfReq.get_tensor("input_ids");
+
+    // add input_ids to our cache
+    {
+      auto* pData = input_ids_tensor.data<int64_t>();
+      for (size_t i = 0; i < input_ids_tensor.get_size(); i++) {
+        cached_input_ids.push_back(pData[i]);
+      }
+    }
+
+    // add position_ids to our cache
+    {
+      auto position_ids = ovInfReq.get_tensor("position_ids");
+      auto* pData = position_ids.data<int64_t>();
+      for (size_t i = 0; i < position_ids.get_size(); i++) {
+        cached_position_ids.push_back(pData[i]);
+      }
+    }
+
+    // if we're about to run prefill model
+    if (input_ids_tensor.get_size() > 1) {
+      // if the input_ids size doesn't equal cached size of the input_ids
+      // then it means that we're running 2nd (or later) prompt.
+      if (input_ids_tensor.get_shape()[1] != cached_input_ids.size()) {
+        // set a new input_ids tensor with the content of our cached input_ids
+        {
+          auto new_shape = input_ids_tensor.get_shape();
+          new_shape[1] = cached_input_ids.size();
+          auto new_input_ids = ov::Tensor(input_ids_tensor.get_element_type(), new_shape);
+          auto* pNewInputIds = new_input_ids.data<int64_t>();
+          std::memcpy(pNewInputIds, cached_input_ids.data(), cached_input_ids.size() * sizeof(int64_t));
+          ovInfReq.set_tensor("input_ids", new_input_ids);
+        }
+
+        // set a new position_ids tensor with the content of our cached position_ids
+        {
+          auto position_ids_tensor = ovInfReq.get_tensor("position_ids");
+          auto new_shape = position_ids_tensor.get_shape();
+          new_shape[1] = cached_position_ids.size();
+          auto new_position_ids = ov::Tensor(position_ids_tensor.get_element_type(), new_shape);
+          auto* pNewPositionIds = new_position_ids.data<int64_t>();
+          std::memcpy(pNewPositionIds, cached_position_ids.data(), cached_position_ids.size() * sizeof(int64_t));
+          ovInfReq.set_tensor("position_ids", new_position_ids);
+        }
+      }
+    }
+  }
+}
+
+void StatefulOVInferRequest::StartAsync() {
+  _pre_infer();
+  OVInferRequest::StartAsync();
+}
+
+void StatefulOVInferRequest::Infer() {
+  _pre_infer();
+  OVInferRequest::Infer();
+}
+
 } // namespace openvino_ep
 } // namespace onnxruntime
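
StatefulOVInferRequest::_pre_infer above is the heart of the chat-mode support: on NPU it appends every incoming input_ids/position_ids chunk to a cache, and when a later prefill arrives whose length differs from the cache it swaps in tensors rebuilt from the full cached history. A standalone sketch of that rule, using a hypothetical ChatCache type and plain vectors instead of ov::Tensor (input_ids only; position_ids follows the same logic):

// Standalone sketch (no OpenVINO; hypothetical ChatCache/PreInfer names).
#include <cstdint>
#include <iostream>
#include <vector>

struct ChatCache {
  std::vector<int64_t> cached_input_ids;

  // Returns the token ids that should actually be submitted for this step.
  std::vector<int64_t> PreInfer(const std::vector<int64_t>& input_ids) {
    cached_input_ids.insert(cached_input_ids.end(), input_ids.begin(), input_ids.end());
    // Prefill (>1 token) whose length differs from the cache means a 2nd or
    // later prompt: substitute the full cached history.
    if (input_ids.size() > 1 && input_ids.size() != cached_input_ids.size()) {
      return cached_input_ids;
    }
    return input_ids;
  }
};

int main() {
  ChatCache cache;
  auto turn1 = cache.PreInfer({101, 7592, 102});  // 1st prompt: submitted as-is
  auto tok   = cache.PreInfer({2009});            // decode step: single token
  auto turn2 = cache.PreInfer({103, 2129, 104});  // 2nd prompt: rebuilt from cache
  std::cout << turn1.size() << " " << tok.size() << " " << turn2.size() << "\n";
  // Prints "3 1 7": the second prefill carries the full 7-token history.
  return 0;
}

The second prompt's prefill therefore carries the whole conversation so far, not just the newly appended turn.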

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 25 additions & 5 deletions
@@ -105,24 +105,27 @@ struct OVCore : WeakSingleton<OVCore> {
 
 class OVExeNetwork {
   ov::CompiledModel obj;
-
+  std::string _device;
+  bool _stateful_llm;
  public:
-  explicit OVExeNetwork(ov::CompiledModel md) : obj(md) {}
+  explicit OVExeNetwork(ov::CompiledModel md, std::string device, bool stateful_llm = false)
+      : obj(md), _device(device), _stateful_llm(stateful_llm) {}
   OVExeNetwork() : obj(ov::CompiledModel()) {}
   ov::CompiledModel& Get() { return obj; }
-  OVInferRequest CreateInferRequest();
+  std::shared_ptr<OVInferRequest> CreateInferRequest();
 };
 
 class OVInferRequest {
+ protected:
   ov::InferRequest ovInfReq;
 
  public:
   uint32_t GetNumInputs();
   OVTensorPtr GetTensor(const std::string& name);
   std::string GetInputTensorName(uint32_t index);
   void SetTensor(const std::string& name, OVTensorPtr& blob);
-  void StartAsync();
-  void Infer();
+  virtual void StartAsync();
+  virtual void Infer();
   void WaitRequest();
   void QueryStatus();
   explicit OVInferRequest(ov::InferRequest obj) : ovInfReq(std::move(obj)) {}
@@ -131,5 +134,22 @@ class OVInferRequest {
     return ovInfReq;
   }
 };
+
+class StatefulOVInferRequest : public OVInferRequest {
+ public:
+  explicit StatefulOVInferRequest(ov::InferRequest obj, std::string d) : OVInferRequest(std::move(obj)), device(d) {}
+
+  void StartAsync() override;
+  void Infer() override;
+
+ private:
+  void _pre_infer();
+  std::string device;
+
+  // For NPU, we need to cache input_ids & position_ids to support chat-mode.
+  std::vector<int64_t> cached_input_ids;
+  std::vector<int64_t> cached_position_ids;
+};
+
 } // namespace openvino_ep
 } // namespace onnxruntime
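
The declarations above follow an override-and-delegate shape: StatefulOVInferRequest overrides StartAsync()/Infer() only to run its private _pre_infer() hook (beam_idx setup plus the NPU id caching) before calling the base-class implementation. A minimal sketch of that shape with hypothetical BaseRequest/StatefulRequest names and no OpenVINO dependency:

// Minimal standalone sketch (hypothetical stand-in classes, not the EP types).
#include <iostream>
#include <string>

class BaseRequest {
 public:
  virtual ~BaseRequest() = default;
  virtual void Infer() { std::cout << "run inference\n"; }
};

class StatefulRequest : public BaseRequest {
 public:
  explicit StatefulRequest(std::string device) : device_(std::move(device)) {}

  void Infer() override {
    PreInfer();            // inject state handling first...
    BaseRequest::Infer();  // ...then reuse the ordinary inference path
  }

 private:
  void PreInfer() {
    if (device_.find("NPU") != std::string::npos) {
      std::cout << "cache ids / set beam_idx for " << device_ << "\n";
    }
  }
  std::string device_;
};

int main() {
  StatefulRequest req("NPU");
  req.Infer();  // prints the pre-infer line, then "run inference"
  return 0;
}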
