Commit 80cac84
Fix issue and revert states type
Signed-off-by: Bogdan Pereanu <[email protected]>
1 parent 2ecb526 commit 80cac84

File tree

4 files changed: 51 additions, 53 deletions
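In short, this commit reverts `_variableStates` to the generic `ov::SoPtr<ov::IVariableState>` type owned by the common `SyncInferRequest`, moves `query_state()`/`initialize_states()` into that base class, and has the Level Zero backend downcast each entry to `ZeroVariableState` only where backend-specific bookkeeping is needed. The sketch below is a minimal standalone illustration of that ownership/downcast pattern; it is not OpenVINO code, it uses plain std::shared_ptr instead of ov::SoPtr, and all class and member names in it are simplified stand-ins.

// Minimal standalone sketch (not OpenVINO code): simplified stand-ins that
// mirror the ov::IVariableState / ZeroVariableState relationship.
#include <cassert>
#include <iostream>
#include <memory>
#include <vector>

// Stand-in for the generic state interface the base request owns.
struct IVariableStateLike {
    virtual ~IVariableStateLike() = default;
    virtual void reset() = 0;
};

// Stand-in for the backend-specific state with extra bookkeeping.
struct ZeroVariableStateLike : IVariableStateLike {
    bool updated = false;
    void reset() override { updated = false; }
    bool tensor_was_updated() const { return updated; }
};

// The base request stores states only through the generic interface,
// so query_state()/initialize_states() need no backend knowledge.
struct BaseRequest {
    std::vector<std::shared_ptr<IVariableStateLike>> states;

    void initialize_states() {
        for (auto& state : states) state->reset();
    }
    std::vector<std::shared_ptr<IVariableStateLike>> query_state() const {
        return states;  // no per-element upcast loop needed anymore
    }
};

// The backend request downcasts only where the derived API is required,
// mirroring the dynamic_pointer_cast + OPENVINO_ASSERT added in this commit.
struct ZeroLikeRequest : BaseRequest {
    void update_states_if_memory_changed() {
        for (auto& state : states) {
            auto zeroState = std::dynamic_pointer_cast<ZeroVariableStateLike>(state);
            assert(zeroState && "State is not compatible with this backend");
            if (zeroState->tensor_was_updated()) {
                // ...refresh cached backend tensors here...
                zeroState->updated = false;
            }
        }
    }
};

int main() {
    ZeroLikeRequest request;
    request.states.push_back(std::make_shared<ZeroVariableStateLike>());
    request.update_states_if_memory_changed();
    request.initialize_states();
    std::cout << "states: " << request.query_state().size() << "\n";
    return 0;
}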

src/plugins/intel_npu/src/backend/include/zero_infer_request.hpp

Lines changed: 3 additions & 11 deletions
@@ -35,14 +35,6 @@ class ZeroInferRequest final : public SyncInferRequest {
 
     void get_result() override;
 
-    std::vector<ov::SoPtr<ov::IVariableState>> query_state() const override;
-
-    /**
-     * @brief Initializes the tensor values corresponding to the state variables.
-     * @details The inital values are usually all 0s.
-     */
-    void initialize_states() override;
-
 private:
     std::vector<ov::ProfilingInfo> get_profiling_info() const override;
 
@@ -64,7 +56,9 @@ class ZeroInferRequest final : public SyncInferRequest {
                                                 const bool isInput,
                                                 const std::optional<std::size_t> batchSize = std::nullopt) const;
 
-    void add_state(const IODescriptor& descriptor, size_t tensorIndex) const;
+    void add_state(const IODescriptor& descriptor,
+                   size_t tensorIndex,
+                   const std::shared_ptr<ZeroTensor>& zeroTensor) const;
 
     void update_pipeline_if_memory_changed();
     void update_states_if_memory_changed();
@@ -82,8 +76,6 @@ class ZeroInferRequest final : public SyncInferRequest {
     mutable std::vector<std::vector<std::shared_ptr<ZeroTensor>>> _levelZeroInputTensors;
     mutable std::vector<std::shared_ptr<ZeroTensor>> _levelZeroOutputTensors;
 
-    mutable std::vector<ov::SoPtr<ZeroVariableState>> _variableStates;
-
     std::unique_ptr<Pipeline> _pipeline;
 
     bool _pipelineIsCreated = false;

src/plugins/intel_npu/src/backend/src/zero_infer_request.cpp

Lines changed: 33 additions & 41 deletions
@@ -238,19 +238,22 @@ void ZeroInferRequest::create_pipeline() {
     _logger.debug("ZeroInferRequest::create_pipeline - set new tensors and reset variable state flag if memory updated "
                   "before creating the pipeline");
     for (const auto& variableState : _variableStates) {
-        if (variableState->tensor_was_updated()) {
+        auto zeroState = std::dynamic_pointer_cast<ZeroVariableState>(variableState._ptr);
+        OPENVINO_ASSERT(zeroState != nullptr, "State is not compatible with NPU plugin");
+
+        if (zeroState->tensor_was_updated()) {
             _logger.debug("ZeroInferRequest::create_pipeline - user state tensor should be updated");
 
-            get_user_input(variableState->get_tensor_index()) = variableState->get_state();
-            _userOutputTensors.at(variableState->get_related_tensor_index()) = variableState->get_state();
-            variableState->reset_tensor_updated_flag();
+            get_user_input(zeroState->get_tensor_index()) = zeroState->get_state();
+            _userOutputTensors.at(zeroState->get_related_tensor_index()) = zeroState->get_state();
+            zeroState->reset_tensor_updated_flag();
 
-            if (variableState->zero_tensor_should_be_updated()) {
+            if (zeroState->zero_tensor_should_be_updated()) {
                 _logger.debug("ZeroInferRequest::create_pipeline - level zero state tensor should be updated");
 
-                get_level_zero_input(variableState->get_tensor_index()) = variableState->get_zero_state();
-                _levelZeroOutputTensors.at(variableState->get_related_tensor_index()) = variableState->get_zero_state();
-                variableState->reset_zero_tensor_updated_flag();
+                get_level_zero_input(zeroState->get_tensor_index()) = zeroState->get_zero_state();
+                _levelZeroOutputTensors.at(zeroState->get_related_tensor_index()) = zeroState->get_zero_state();
+                zeroState->reset_zero_tensor_updated_flag();
             }
         }
     }
@@ -569,7 +572,7 @@ std::shared_ptr<ZeroTensor> ZeroInferRequest::allocate_tensor(const size_t index
         }
 
         if (descriptor.isStateInput) {
-            add_state(descriptor, index);
+            add_state(descriptor, index, tensor);
         }
     } else if (_userOutputTensors.at(index) == nullptr) {
         _userOutputTensors.at(index) = tensor;
@@ -633,25 +636,27 @@ void ZeroInferRequest::update_pipeline_if_memory_changed() {
 
 void ZeroInferRequest::update_states_if_memory_changed() {
     for (const auto& variableState : _variableStates) {
-        if (variableState->tensor_was_updated()) {
-            get_user_input(variableState->get_tensor_index()) = variableState->get_state();
-            _userOutputTensors.at(variableState->get_related_tensor_index()) = variableState->get_state();
-            variableState->reset_tensor_updated_flag();
+        auto zeroState = std::dynamic_pointer_cast<ZeroVariableState>(variableState._ptr);
+        OPENVINO_ASSERT(zeroState != nullptr, "State is not compatible with NPU plugin");
 
-            if (variableState->zero_tensor_should_be_updated()) {
-                get_level_zero_input(variableState->get_tensor_index()) = variableState->get_zero_state();
-                _levelZeroOutputTensors.at(variableState->get_related_tensor_index()) = variableState->get_zero_state();
-                variableState->reset_zero_tensor_updated_flag();
+        if (zeroState->tensor_was_updated()) {
+            get_user_input(zeroState->get_tensor_index()) = zeroState->get_state();
+            _userOutputTensors.at(zeroState->get_related_tensor_index()) = zeroState->get_state();
+            zeroState->reset_tensor_updated_flag();
 
-                _pipeline->update_graph_arguments(
-                    _graphInputDescriptors.at(variableState->get_tensor_index()).idx,
-                    get_level_zero_input(variableState->get_tensor_index())->data(),
-                    get_level_zero_input(variableState->get_tensor_index())->get_byte_size());
+            if (zeroState->zero_tensor_should_be_updated()) {
+                get_level_zero_input(zeroState->get_tensor_index()) = zeroState->get_zero_state();
+                _levelZeroOutputTensors.at(zeroState->get_related_tensor_index()) = zeroState->get_zero_state();
+                zeroState->reset_zero_tensor_updated_flag();
+
+                _pipeline->update_graph_arguments(_graphInputDescriptors.at(zeroState->get_tensor_index()).idx,
+                                                  get_level_zero_input(zeroState->get_tensor_index())->data(),
+                                                  get_level_zero_input(zeroState->get_tensor_index())->get_byte_size());
 
                 _pipeline->update_graph_arguments(
-                    _graphOutputDescriptors.at(variableState->get_related_tensor_index()).idx,
-                    _levelZeroOutputTensors.at(variableState->get_related_tensor_index())->data(),
-                    _levelZeroOutputTensors.at(variableState->get_related_tensor_index())->get_byte_size());
+                    _graphOutputDescriptors.at(zeroState->get_related_tensor_index()).idx,
+                    _levelZeroOutputTensors.at(zeroState->get_related_tensor_index())->data(),
+                    _levelZeroOutputTensors.at(zeroState->get_related_tensor_index())->get_byte_size());
             }
         }
     }
@@ -842,21 +847,6 @@ void ZeroInferRequest::get_result() {
     _logger.debug("InferRequest::get_result finished");
 }
 
-void ZeroInferRequest::initialize_states() {
-    for (const auto& variableState : _variableStates) {
-        variableState->reset();
-    }
-}
-
-std::vector<ov::SoPtr<ov::IVariableState>> ZeroInferRequest::query_state() const {
-    std::vector<ov::SoPtr<ov::IVariableState>> result;
-    result.reserve(_variableStates.size());
-    for (const auto& state : _variableStates) {
-        result.push_back(state);  // Implicit upcast from SoPtr<ZeroVariableState> to SoPtr<IVariableState>
-    }
-    return result;
-}
-
 void ZeroInferRequest::check_network_precision(const ov::element::Type_t precision) const {
     switch (precision) {
     case ov::element::Type_t::f32:
@@ -913,14 +903,16 @@ std::vector<ov::ProfilingInfo> ZeroInferRequest::get_profiling_info() const {
     return _pipeline->get_profiling_info();
 }
 
-void ZeroInferRequest::add_state(const IODescriptor& descriptor, size_t tensorIndex) const {
+void ZeroInferRequest::add_state(const IODescriptor& descriptor,
+                                 size_t tensorIndex,
+                                 const std::shared_ptr<ZeroTensor>& zeroTensor) const {
     OPENVINO_ASSERT(descriptor.relatedDescriptorIndex.has_value(),
                     "The link between state descriptors is missing, state name: ",
                     descriptor.nameFromCompiler);
 
    _variableStates.push_back(std::make_shared<ZeroVariableState>(_initStructs,
                                                                   descriptor.nameFromCompiler,
-                                                                  get_level_zero_input(tensorIndex),
+                                                                  zeroTensor,
                                                                   tensorIndex,
                                                                   descriptor.relatedDescriptorIndex.value(),
                                                                   _config));

src/plugins/intel_npu/src/common/include/intel_npu/common/sync_infer_request.hpp

Lines changed: 5 additions & 1 deletion
@@ -84,11 +84,13 @@ class SyncInferRequest : public ov::IInferRequest {
      */
     virtual void get_result() = 0;
 
+    std::vector<ov::SoPtr<ov::IVariableState>> query_state() const override;
+
     /**
      * @brief Initializes the tensor values corresponding to the state variables.
      * @details The inital values are usually all 0s.
      */
-    virtual void initialize_states() = 0;
+    void initialize_states();
 
 protected:
     /**

@@ -162,6 +164,8 @@
     mutable std::vector<std::vector<ov::SoPtr<ov::ITensor>>> _userInputTensors;
     mutable std::vector<ov::SoPtr<ov::ITensor>> _userOutputTensors;
 
+    mutable std::vector<ov::SoPtr<ov::IVariableState>> _variableStates;
+
     /**
      * @see ov::ISyncInferRequest
      */

src/plugins/intel_npu/src/common/src/sync_infer_request.cpp

Lines changed: 10 additions & 0 deletions
@@ -102,6 +102,16 @@ const std::shared_ptr<const ov::ICompiledModel>& SyncInferRequest::get_compiled_
     return _compiledModel;
 }
 
+void SyncInferRequest::initialize_states() {
+    for (const ov::SoPtr<ov::IVariableState>& variableState : _variableStates) {
+        variableState->reset();
+    }
+}
+
+std::vector<ov::SoPtr<ov::IVariableState>> SyncInferRequest::query_state() const {
+    return _variableStates;
+}
+
 ov::SoPtr<ov::ITensor> SyncInferRequest::get_tensor(const ov::Output<const ov::Node>& port) const {
     auto foundPort = find_port(port);
     OPENVINO_ASSERT(foundPort.found(), "Cannot find tensor for port ", port);
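From the application side, this plumbing is what answers the standard stateful-inference calls. Below is a hedged usage sketch against the public OpenVINO C++ API; the model path and the "NPU" device string are placeholders, and the snippet assumes a stateful model that compiles for that device.

#include <iostream>
#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // Placeholder model path and device; any stateful model compiled for NPU applies.
    auto compiled = core.compile_model("stateful_model.xml", "NPU");
    ov::InferRequest request = compiled.create_infer_request();

    request.infer();  // one inference step of the stateful model

    // query_state()/reset() are served by SyncInferRequest, backed by the
    // _variableStates vector this commit moves into the base class.
    for (ov::VariableState& state : request.query_state()) {
        std::cout << state.get_name() << std::endl;
        state.reset();  // re-initializes the state tensor (usually to zeros)
    }
    return 0;
}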
