|
21 | 21 |
|
22 | 22 | #include "inference_engine.h" |
23 | 23 | #include "ngraph/event_tracing.hpp" |
| 24 | +#include "ngraph_backend_manager.h" |
24 | 25 | #include "version.h" |
25 | 26 |
|
26 | 27 | #include "tensorflow/core/platform/macros.h" |
@@ -72,11 +73,16 @@ Status InferenceEngine::Load(const string& network, const string& image_file, |
72 | 73 |
|
73 | 74 | // Preload the image is requested |
74 | 75 | if (m_preload_images) { |
| 76 | + // Set the CPU as the backend before these ops |
| 77 | + string current_backend = |
| 78 | + tf::ngraph_bridge::BackendManager::GetCurrentlySetBackendName(); |
| 79 | + tf::ngraph_bridge::BackendManager::SetBackendName("CPU"); |
75 | 80 | std::vector<tf::Tensor> resized_tensors; |
76 | 81 | TF_CHECK_OK(ReadTensorFromImageFile( |
77 | 82 | m_image_file, m_input_height, m_input_width, m_input_mean, m_input_std, |
78 | 83 | m_use_NCHW, &resized_tensors)); |
79 | 84 | m_image_to_repeat = resized_tensors[0]; |
| 85 | + tf::ngraph_bridge::BackendManager::SetBackendName(current_backend); |
80 | 86 | } |
81 | 87 | // Now compile the graph if needed |
82 | 88 | // This would be useful to detect errors early. For a graph |
@@ -119,12 +125,18 @@ void InferenceEngine::ThreadMain() { |
119 | 125 | cout << "[" << m_name << "] " << step_count << ": Reading image\n"; |
120 | 126 | ngraph::Event read_event("Read", "", ""); |
121 | 127 |
|
| 128 | + string current_backend = |
| 129 | + tf::ngraph_bridge::BackendManager::GetCurrentlySetBackendName(); |
| 130 | + tf::ngraph_bridge::BackendManager::SetBackendName("CPU"); |
| 131 | + |
122 | 132 | std::vector<tf::Tensor> resized_tensors; |
123 | 133 | TF_CHECK_OK(ReadTensorFromImageFile( |
124 | 134 | m_image_file, m_input_height, m_input_width, m_input_mean, |
125 | 135 | m_input_std, m_use_NCHW, &resized_tensors)); |
126 | 136 |
|
127 | 137 | m_image_to_repeat = resized_tensors[0]; |
| 138 | + tf::ngraph_bridge::BackendManager::SetBackendName(current_backend); |
| 139 | + |
128 | 140 | read_event.Stop(); |
129 | 141 | ngraph::Event::write_trace(read_event); |
130 | 142 | } |
|
0 commit comments