@@ -196,13 +196,22 @@ NGraphEncapsulateOp::NGraphEncapsulateOp(OpKernelConstruction* ctx)
196196 BackendManager::SetConfig (ng_encap_impl.GetOpBackend (),
197197 additional_attribute_map);
198198
199- ng_encap_impl.SetExecCanCreateTensor (
199+ // For NNPI (even though executable can create tensor) use backend to create
200+ // tensor
201+ // Keep the executable_can_create_tensors check before the
202+ // backend_name!="NNPI"
203+ bool executable_create_tensor =
200204 BackendManager::GetBackend (ng_encap_impl.GetOpBackend ())
201- ->executable_can_create_tensors ());
205+ ->executable_can_create_tensors () &&
206+ (backend_name != " NNPI" );
207+ ng_encap_impl.SetExecCanCreateTensor (executable_create_tensor);
202208 NGRAPH_VLOG (5 ) << " Executable can "
203209 << (ng_encap_impl.GetExecCanCreateTensor () ? " " : " not" )
204210 << " create tensors" ;
205211
212+ const char * not_persistent_flag = std::getenv (" NGRAPH_TF_DISABLE_PERSISTENT" );
213+ m_use_persistent = (not_persistent_flag == nullptr );
214+
206215 event.Stop ();
207216 ngraph::Event::write_trace (event);
208217}
@@ -262,6 +271,7 @@ NGraphEncapsulateOp::~NGraphEncapsulateOp() {
262271 ng_encap_impl.ClearNgExecMap ();
263272 ng_encap_impl.ClearNgExecPipelinedTensorMap ();
264273 ng_encap_impl.ClearNgExecSerializedFunctionCache ();
274+ ng_encap_impl.ClearNgExecPersistentOutputCache ();
265275
266276 // Release the backend
267277 NGRAPH_VLOG (2 ) << " ~NGraphEncapsulateOp():: ReleaseBackend" ;
@@ -345,9 +355,20 @@ void NGraphEncapsulateOp::Compute(OpKernelContext* ctx) {
345355 // Allocate tensors for the output results.
346356 vector<shared_ptr<ng::runtime::Tensor>> ng_outputs;
347357 std::vector<Tensor*> tf_output_tensors;
358+ std::vector<tensorflow::PersistentTensor> cached_persistent_output_tensors (
359+ ng_exec->get_results ().size ());
360+ bool present_in_cache = false ;
348361
349362 {
350363 NG_TRACE (" NGTF_Output_Alloc" , " " );
364+ if (m_use_persistent) {
365+ present_in_cache = ng_encap_impl.PersistentOutputsExist (ng_exec);
366+ if (present_in_cache) {
367+ OP_REQUIRES_OK (ctx, ng_encap_impl.GetPersistentTFOutputTensor (
368+ ng_exec, cached_persistent_output_tensors));
369+ }
370+ }
371+
351372 for (auto i = 0 ; i < ng_exec->get_results ().size (); i++) {
352373 auto ng_element = ng_exec->get_results ()[i];
353374 auto ng_shape = ng_element->get_shape ();
@@ -360,21 +381,40 @@ void NGraphEncapsulateOp::Compute(OpKernelContext* ctx) {
360381 }
361382 TensorShape tf_shape (dims);
362383 Tensor* output_tensor = nullptr ;
363- OP_REQUIRES_OK (ctx, ctx->allocate_output (i, tf_shape, &output_tensor));
364- tf_output_tensors.push_back (output_tensor);
365384
366385 // Make sure the nGraph-inferred element type agrees with what TensorFlow
367386 // expected.
368387 ng::element::Type expected_elem_type;
388+ // TODO, we only need to do these checks once when the exec was
389+ // created/compiled, not again and again
390+
369391 OP_REQUIRES_OK (
370392 ctx, TFDataTypeToNGraphElementType (ctx->expected_output_dtype (i),
371393 &expected_elem_type));
372394 OP_REQUIRES (
373395 ctx, ng_element_type == expected_elem_type,
374396 errors::Internal (" Element type inferred by nGraph does not match "
375397 " the element type expected by TensorFlow" ));
376- }
377398
399+ if (m_use_persistent) {
400+ if (present_in_cache) {
401+ output_tensor = cached_persistent_output_tensors[i].AccessTensor (ctx);
402+ } else {
403+ // create a persistent tensor
404+ OP_REQUIRES_OK (
405+ ctx, ctx->allocate_persistent (
406+ ctx->expected_output_dtype (i), tf_shape,
407+ &cached_persistent_output_tensors[i], &output_tensor));
408+ }
409+ } else {
410+ OP_REQUIRES_OK (ctx, ctx->allocate_output (i, tf_shape, &output_tensor));
411+ }
412+ tf_output_tensors.push_back (output_tensor);
413+ }
414+ if (m_use_persistent && !present_in_cache) {
415+ OP_REQUIRES_OK (ctx, ng_encap_impl.RegisterPersistentOutputTensors (
416+ ng_exec, cached_persistent_output_tensors));
417+ }
378418 OP_REQUIRES_OK (ctx, ng_encap_impl.AllocateNGOutputTensors (
379419 tf_output_tensors, ng_exec, out_group_from_pipeline,
380420 op_backend, ng_outputs));
@@ -611,6 +651,16 @@ void NGraphEncapsulateOp::Compute(OpKernelContext* ctx) {
611651 exp.what (), " \n " ));
612652 }
613653 }
654+
655+ if (m_use_persistent) {
656+ for (int out_idx = 0 ; out_idx < ng_exec->get_results ().size (); out_idx++) {
657+ OP_REQUIRES_OK (ctx, ng_encap_impl.GetPersistentTFOutputTensor (
658+ ng_exec, cached_persistent_output_tensors));
659+ auto out_tensor =
660+ cached_persistent_output_tensors[out_idx].AccessTensor (ctx);
661+ ctx->set_output (out_idx, *out_tensor);
662+ }
663+ }
614664} // end compute
615665
616666int NGraphEncapsulateImpl::s_instance_count = 0 ;
0 commit comments