forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpass.hpp
98 lines (79 loc) · 2.77 KB
/
pass.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "openvino/core/core_visibility.hpp"
#include "openvino/core/deprecated.hpp"
#include "openvino/core/enum_mask.hpp"
#include "openvino/core/model.hpp"
#include "openvino/core/node.hpp"
#include "openvino/pass/pass_config.hpp"
namespace ov {
namespace pass {
/// Flags a pass can declare about itself. Each enumerator is a distinct single
/// bit so that several of them can be combined in one PassPropertyMask.
enum class PassProperty : uint32_t {
    /// Pass requires node shapes to be static
    REQUIRE_STATIC_SHAPE = 1u << 0,
    /// Pass transformation will change the function's dynamic state
    CHANGE_DYNAMIC_STATE = 1u << 1,
};
/// Combination of PassProperty flags.
using PassPropertyMask = ov::EnumMask<PassProperty>;
/**
 * @brief Base class for transformation passes
 * @ingroup ov_pass_cpp_api
 *
 * Holds a pass name, a mask of PassProperty flags and a shared PassConfig.
 * Concrete passes provide get_type_info(), which is also the key used to look
 * up this pass's callback in the attached PassConfig.
 */
class OPENVINO_API PassBase {
    friend class Manager;

public:
    PassBase();
    virtual ~PassBase() = default;

    /// \brief Checks whether this pass has all the pass properties in \p prop_mask.
    /// \param prop_mask Mask of PassProperty flags to test.
    bool get_property(const PassPropertyMask& prop_mask) const;

    /// \brief Sets the name of this pass instance.
    /// \param name New name to store.
    void set_name(const std::string& name) {
        m_name = name;
    }

    /// \brief Returns the name of this pass (defined out of line).
    std::string get_name() const;

    /// \brief Sets a callback for this particular transformation type.
    /// This method sets a global callback. For more details see the PassConfig
    /// class documentation.
    /// \param callback lambda function that takes a node and returns bool
    void set_callback(const param_callback& callback);

    /// \brief Sets the PassConfig for this particular transformation instance.
    /// \param pass_config is a PassConfig shared_ptr
    virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) {
        m_pass_config = pass_config;
    }

    /// \brief Gives access to the shared PassConfig instance.
    /// \return Shared instance of the PassConfig class; this is null if a null
    ///         pointer was passed to set_pass_config().
    std::shared_ptr<PassConfig> get_pass_config() {
        return m_pass_config;
    }

    /// \brief Applies the callback registered for this pass type to the given node.
    /// By default the callback returns false.
    /// \param node which will be used inside the callback
    /// \return result of the callback for the given node, or false when no
    ///         PassConfig is attached.
    bool transformation_callback(const std::shared_ptr<const Node>& node) {
        // set_pass_config() is public and virtual, so nothing stops a caller from
        // installing a null config. Fall back to the documented default instead of
        // dereferencing a null shared_ptr (undefined behavior).
        if (!m_pass_config) {
            return false;
        }
        return m_pass_config->get_callback(get_type_info())(node);
    }

    using type_info_t = DiscreteTypeInfo;
    /// \brief Returns the type info of the concrete pass. It is also used as the
    ///        key for callback lookup in transformation_callback().
    virtual const type_info_t& get_type_info() const = 0;

protected:
    /// \brief Updates the flags in \p prop according to \p value (defined out of
    ///        line; presumably sets them when true and clears them when false).
    void set_property(const PassPropertyMask& prop, bool value);

private:
    // NOTE: member order is kept as-is so the object layout does not change.
    PassPropertyMask m_property;
    std::string m_name;
    std::shared_ptr<PassConfig> m_pass_config;
};
/// \brief Base class for Model passes
/// \ingroup ov_pass_cpp_api
///
/// A derived pass implements run_on_model() to process an entire ov::Model.
class OPENVINO_API ModelPass : public PassBase {
public:
    OPENVINO_RTTI("ov::pass::ModelPass");

    ~ModelPass() override;

    /// \brief Entry point of the pass; must be implemented by every derived pass.
    /// \param model The model to process.
    /// \return Pass-defined result. Presumably true when the model was changed;
    ///         confirm against the pass Manager before relying on it.
    virtual bool run_on_model(const std::shared_ptr<ov::Model>& model) = 0;
};
} // namespace pass
} // namespace ov