Skip to content

Commit baddde2

Browse files
Merge branch 'master' into r0.15
2 parents 9d5ca5f + c7c5ce3 commit baddde2

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

examples/cpp/infer_multiple_networks/inference_engine.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include "inference_engine.h"
2323
#include "ngraph/event_tracing.hpp"
24+
#include "ngraph_backend_manager.h"
2425
#include "version.h"
2526

2627
#include "tensorflow/core/platform/macros.h"
@@ -72,11 +73,16 @@ Status InferenceEngine::Load(const string& network, const string& image_file,
7273

7374
// Preload the image is requested
7475
if (m_preload_images) {
76+
// Set the CPU as the backend before these ops
77+
string current_backend =
78+
tf::ngraph_bridge::BackendManager::GetCurrentlySetBackendName();
79+
tf::ngraph_bridge::BackendManager::SetBackendName("CPU");
7580
std::vector<tf::Tensor> resized_tensors;
7681
TF_CHECK_OK(ReadTensorFromImageFile(
7782
m_image_file, m_input_height, m_input_width, m_input_mean, m_input_std,
7883
m_use_NCHW, &resized_tensors));
7984
m_image_to_repeat = resized_tensors[0];
85+
tf::ngraph_bridge::BackendManager::SetBackendName(current_backend);
8086
}
8187
// Now compile the graph if needed
8288
// This would be useful to detect errors early. For a graph
@@ -119,12 +125,18 @@ void InferenceEngine::ThreadMain() {
119125
cout << "[" << m_name << "] " << step_count << ": Reading image\n";
120126
ngraph::Event read_event("Read", "", "");
121127

128+
string current_backend =
129+
tf::ngraph_bridge::BackendManager::GetCurrentlySetBackendName();
130+
tf::ngraph_bridge::BackendManager::SetBackendName("CPU");
131+
122132
std::vector<tf::Tensor> resized_tensors;
123133
TF_CHECK_OK(ReadTensorFromImageFile(
124134
m_image_file, m_input_height, m_input_width, m_input_mean,
125135
m_input_std, m_use_NCHW, &resized_tensors));
126136

127137
m_image_to_repeat = resized_tensors[0];
138+
tf::ngraph_bridge::BackendManager::SetBackendName(current_backend);
139+
128140
read_event.Stop();
129141
ngraph::Event::write_trace(read_event);
130142
}

0 commit comments

Comments
 (0)