Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class ZeroInferRequest final : public SyncInferRequest {
const bool isInput,
const std::optional<std::size_t> batchSize = std::nullopt) const;

void add_state(const IODescriptor& descriptor, size_t tensorIndex) const override;
void add_state(const IODescriptor& descriptor, size_t tensorIndex) const;

void update_pipeline_if_memory_changed();
void update_states_if_memory_changed();
Expand Down
59 changes: 34 additions & 25 deletions src/plugins/intel_npu/src/backend/include/zero_variable_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "intel_npu/utils/logger/logger.hpp"
#include "intel_npu/utils/zero/zero_init.hpp"
#include "openvino/runtime/ivariable_state.hpp"
#include "zero_tensor.hpp"

namespace intel_npu {

Expand All @@ -20,57 +21,66 @@ class ZeroVariableState final : public ov::IVariableState {
public:
explicit ZeroVariableState(const std::shared_ptr<ZeroInitStructsHolder>& init_structs,
const std::string& name,
const ov::SoPtr<ov::ITensor>& tensor,
const std::shared_ptr<ZeroTensor>& zero_tensor,
size_t tensor_index,
size_t related_tensor_index,
const Config& config,
bool external_memory_standard_allocation_supported);
const Config& config);

void set_state(const ov::SoPtr<ov::ITensor>& new_state) override;

void reset() override;

ov::SoPtr<ov::ITensor> get_state() const override;

/**
* @brief Get input tensor index used internally for the state
* @brief Get user state to not change the state of the tensor through get_state()
*/
size_t get_tensor_index() const;
ov::SoPtr<ov::ITensor> get_user_state() const;

/**
* @brief Get output tensor index used internally for the state
* @details The related tensors are defined by state input, state output pairs.
* @brief Get internal level zero tensor. It can be different than the user tensor in case the user set a tensor
* that cannot be imported. Used by the InferenceRequest to update the arguments of the pipeline.
*/
size_t get_related_tensor_index() const;
std::shared_ptr<ZeroTensor> get_zero_state() const;

/**
* @brief Get acknowledgment if the tensor was updated
* @brief Get input tensor index used internally for the state
*/
bool tensor_was_updated() const;
size_t get_tensor_index() const;

/**
* @brief Reset tensor updated flag
* @brief Get output tensor index used internally for the state
* @details The related tensors are defined by state input, state output pairs.
*/
void reset_tensor_updated_flag();
size_t get_related_tensor_index() const;

/**
* @brief Get acknowledgment if the zero tensor was updated
* @details In case the memory was allocated in the same level zero context update the zero tensor
* @brief Get acknowledgment if state was updated
* @details Used to check if the state's internal user tensor was updated. Actions might need to be taken by the
* InferenceRequest in that case. This flag can be cleared using clear_state_update_pending(). An update to the user
* tensor might not trigger an update of the level zero tensor as well. zero_state_update_pending() should be used
* to check if the level zero tensor was also updated.
*/
bool zero_tensor_should_be_updated() const;
bool state_update_pending() const;

/**
* @brief Reset zero tensor updated flag
* @brief Reset state updated flag
* @details Must be used to reset the flag exposed through state_update_pending()
*/
void reset_zero_tensor_updated_flag();
void clear_state_update_pending();

/**
* @brief Get acknowledgment if the zero tensor can be imported
* @brief Get acknowledgment if the zero state was updated
* @details Used to signal that the state's internal zero tensor was also updated. Actions might need to be taken by
* the InferenceRequest in that case. This flag can be cleared using clear_zero_state_update_pending().
*/
bool zero_tensor_should_be_imported() const;
bool zero_state_update_pending() const;

/**
* @brief Reset zero tensor imported flag
* @brief Reset zero state updated flag
* @details Must be used to reset the flag exposed through zero_state_update_pending()
*/
void reset_tensor_imported_flag();
void clear_zero_state_update_pending();

~ZeroVariableState() override = default;

Expand All @@ -79,11 +89,10 @@ class ZeroVariableState final : public ov::IVariableState {
size_t _tensor_index;
size_t _related_tensor_index;

bool _tensor_updated = false;
bool _zero_tensor_updated = false;
bool _tensor_should_be_imported = false;
std::shared_ptr<ZeroTensor> _zero_state;

bool _external_memory_standard_allocation_supported = false;
bool _is_state_updated = false;
bool _is_zero_state_update_needed = false;

const Config _config;
Logger _logger;
Expand Down
Loading
Loading