@@ -238,19 +238,22 @@ void ZeroInferRequest::create_pipeline() {
238238 _logger.debug (" ZeroInferRequest::create_pipeline - set new tensors and reset variable state flag if memory updated "
239239 " before creating the pipeline" );
240240 for (const auto & variableState : _variableStates) {
241- if (variableState->tensor_was_updated ()) {
241+ auto zeroState = std::dynamic_pointer_cast<ZeroVariableState>(variableState._ptr );
242+ OPENVINO_ASSERT (zeroState != nullptr , " State is not compatible with NPU plugin" );
243+
244+ if (zeroState->tensor_was_updated ()) {
242245 _logger.debug (" ZeroInferRequest::create_pipeline - user state tensor should be updated" );
243246
244- get_user_input (variableState ->get_tensor_index ()) = variableState ->get_state ();
245- _userOutputTensors.at (variableState ->get_related_tensor_index ()) = variableState ->get_state ();
246- variableState ->reset_tensor_updated_flag ();
247+ get_user_input (zeroState ->get_tensor_index ()) = zeroState ->get_state ();
248+ _userOutputTensors.at (zeroState ->get_related_tensor_index ()) = zeroState ->get_state ();
249+ zeroState ->reset_tensor_updated_flag ();
247250
248- if (variableState ->zero_tensor_should_be_updated ()) {
251+ if (zeroState ->zero_tensor_should_be_updated ()) {
249252 _logger.debug (" ZeroInferRequest::create_pipeline - level zero state tensor should be updated" );
250253
251- get_level_zero_input (variableState ->get_tensor_index ()) = variableState ->get_zero_state ();
252- _levelZeroOutputTensors.at (variableState ->get_related_tensor_index ()) = variableState ->get_zero_state ();
253- variableState ->reset_zero_tensor_updated_flag ();
254+ get_level_zero_input (zeroState ->get_tensor_index ()) = zeroState ->get_zero_state ();
255+ _levelZeroOutputTensors.at (zeroState ->get_related_tensor_index ()) = zeroState ->get_zero_state ();
256+ zeroState ->reset_zero_tensor_updated_flag ();
254257 }
255258 }
256259 }
@@ -569,7 +572,7 @@ std::shared_ptr<ZeroTensor> ZeroInferRequest::allocate_tensor(const size_t index
569572 }
570573
571574 if (descriptor.isStateInput ) {
572- add_state (descriptor, index);
575+ add_state (descriptor, index, tensor );
573576 }
574577 } else if (_userOutputTensors.at (index) == nullptr ) {
575578 _userOutputTensors.at (index) = tensor;
@@ -633,25 +636,27 @@ void ZeroInferRequest::update_pipeline_if_memory_changed() {
633636
634637void ZeroInferRequest::update_states_if_memory_changed () {
635638 for (const auto & variableState : _variableStates) {
636- if (variableState->tensor_was_updated ()) {
637- get_user_input (variableState->get_tensor_index ()) = variableState->get_state ();
638- _userOutputTensors.at (variableState->get_related_tensor_index ()) = variableState->get_state ();
639- variableState->reset_tensor_updated_flag ();
639+ auto zeroState = std::dynamic_pointer_cast<ZeroVariableState>(variableState._ptr );
640+ OPENVINO_ASSERT (zeroState != nullptr , " State is not compatible with NPU plugin" );
640641
641- if (variableState-> zero_tensor_should_be_updated ()) {
642- get_level_zero_input (variableState ->get_tensor_index ()) = variableState-> get_zero_state ();
643- _levelZeroOutputTensors .at (variableState ->get_related_tensor_index ()) = variableState-> get_zero_state ();
644- variableState-> reset_zero_tensor_updated_flag ();
642+ if (zeroState-> tensor_was_updated ()) {
643+ get_user_input (zeroState ->get_tensor_index ()) = zeroState-> get_state ();
644+ _userOutputTensors .at (zeroState ->get_related_tensor_index ()) = zeroState-> get_state ();
645+ zeroState-> reset_tensor_updated_flag ();
645646
646- _pipeline->update_graph_arguments (
647- _graphInputDescriptors.at (variableState->get_tensor_index ()).idx ,
648- get_level_zero_input (variableState->get_tensor_index ())->data (),
649- get_level_zero_input (variableState->get_tensor_index ())->get_byte_size ());
647+ if (zeroState->zero_tensor_should_be_updated ()) {
648+ get_level_zero_input (zeroState->get_tensor_index ()) = zeroState->get_zero_state ();
649+ _levelZeroOutputTensors.at (zeroState->get_related_tensor_index ()) = zeroState->get_zero_state ();
650+ zeroState->reset_zero_tensor_updated_flag ();
651+
652+ _pipeline->update_graph_arguments (_graphInputDescriptors.at (zeroState->get_tensor_index ()).idx ,
653+ get_level_zero_input (zeroState->get_tensor_index ())->data (),
654+ get_level_zero_input (zeroState->get_tensor_index ())->get_byte_size ());
650655
651656 _pipeline->update_graph_arguments (
652- _graphOutputDescriptors.at (variableState ->get_related_tensor_index ()).idx ,
653- _levelZeroOutputTensors.at (variableState ->get_related_tensor_index ())->data (),
654- _levelZeroOutputTensors.at (variableState ->get_related_tensor_index ())->get_byte_size ());
657+ _graphOutputDescriptors.at (zeroState ->get_related_tensor_index ()).idx ,
658+ _levelZeroOutputTensors.at (zeroState ->get_related_tensor_index ())->data (),
659+ _levelZeroOutputTensors.at (zeroState ->get_related_tensor_index ())->get_byte_size ());
655660 }
656661 }
657662 }
@@ -842,21 +847,6 @@ void ZeroInferRequest::get_result() {
842847 _logger.debug (" InferRequest::get_result finished" );
843848}
844849
845- void ZeroInferRequest::initialize_states () {
846- for (const auto & variableState : _variableStates) {
847- variableState->reset ();
848- }
849- }
850-
851- std::vector<ov::SoPtr<ov::IVariableState>> ZeroInferRequest::query_state () const {
852- std::vector<ov::SoPtr<ov::IVariableState>> result;
853- result.reserve (_variableStates.size ());
854- for (const auto & state : _variableStates) {
855- result.push_back (state); // Implicit upcast from SoPtr<ZeroVariableState> to SoPtr<IVariableState>
856- }
857- return result;
858- }
859-
860850void ZeroInferRequest::check_network_precision (const ov::element::Type_t precision) const {
861851 switch (precision) {
862852 case ov::element::Type_t::f32 :
@@ -913,14 +903,16 @@ std::vector<ov::ProfilingInfo> ZeroInferRequest::get_profiling_info() const {
913903 return _pipeline->get_profiling_info ();
914904}
915905
916- void ZeroInferRequest::add_state (const IODescriptor& descriptor, size_t tensorIndex) const {
906+ void ZeroInferRequest::add_state (const IODescriptor& descriptor,
907+ size_t tensorIndex,
908+ const std::shared_ptr<ZeroTensor>& zeroTensor) const {
917909 OPENVINO_ASSERT (descriptor.relatedDescriptorIndex .has_value (),
918910 " The link between state descriptors is missing, state name: " ,
919911 descriptor.nameFromCompiler );
920912
921913 _variableStates.push_back (std::make_shared<ZeroVariableState>(_initStructs,
922914 descriptor.nameFromCompiler ,
923- get_level_zero_input (tensorIndex) ,
915+ zeroTensor ,
924916 tensorIndex,
925917 descriptor.relatedDescriptorIndex .value (),
926918 _config));
0 commit comments