Skip to content

Commit 7ea4c9a

Browse files
committed
Updated the eager API header files.
1 parent faacdef commit 7ea4c9a

File tree

2 files changed

+39
-10
lines changed

2 files changed

+39
-10
lines changed

jni/src/main/native/include/tensorflow/c/c_eager_api.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,16 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig(
6161
// Controls how to act when we try to run an operation on a given device but
6262
// some input tensors are not on that device.
6363
typedef enum TFE_ContextDevicePlacementPolicy {
64-
// The default: running operations with input tensors on the wrong device will
65-
// fail.
64+
// Running operations with input tensors on the wrong device will fail.
6665
TFE_DEVICE_PLACEMENT_EXPLICIT = 0,
6766
// Copy the tensor to the right device but log a warning.
6867
TFE_DEVICE_PLACEMENT_WARN = 1,
6968
// Silently copy the tensor, which has a performance cost since the
7069
// operation will be blocked till the copy completes.
7170
TFE_DEVICE_PLACEMENT_SILENT = 2,
71+
// Default placement policy which silently copies int32 tensors but not other
72+
// dtypes.
73+
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
7274
} TFE_ContextDevicePlacementPolicy;
7375

7476
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
@@ -93,6 +95,18 @@ TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
9395
// ops.
9496
TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx);
9597

98+
// Sets a thread-local device placement policy. After this call, other calls to
99+
// TFE_Execute in the same thread will use the device policy specified here
100+
// instead of the device policy used to construct the context. This has no
101+
// effect on the device policy used by other program threads.
102+
TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy(
103+
TFE_Context*, TFE_ContextDevicePlacementPolicy);
104+
105+
// Returns the device placement policy to be used by this context in the current
106+
// thread.
107+
TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy
108+
TFE_ContextGetDevicePlacementPolicy(TFE_Context*);
109+
96110
// A handle to a tensor on a device.
97111
//
98112
// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,

jni/src/main/native/include/tensorflow/c/c_eager_api_internal.h

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#include <cstddef>
2222
#include <memory>
2323
#include <string>
24+
#include <thread>
2425
#include <vector>
2526

2627
#include "tensorflow/c/c_api.h"
@@ -34,20 +35,34 @@ limitations under the License.
3435
#include "tensorflow/core/lib/gtl/stl_util.h"
3536
#include "tensorflow/core/platform/mutex.h"
3637
#include "tensorflow/core/platform/thread_annotations.h"
38+
#include "tensorflow/core/public/version.h"
3739

3840
struct TFE_ContextOptions {
3941
TF_SessionOptions session_options;
40-
TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT};
42+
TFE_ContextDevicePlacementPolicy policy{
43+
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32};
4144
};
4245

4346
struct TFE_Context {
44-
explicit TFE_Context(TF_Session* s) : session(s) {}
45-
46-
TFE_ContextDevicePlacementPolicy policy;
47+
explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s)
48+
: policy(opts.policy),
49+
session(s),
50+
rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)),
51+
pflr(new tensorflow::ProcessFunctionLibraryRuntime(
52+
session->device_mgr, opts.session_options.options.env,
53+
TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {}
54+
55+
const TFE_ContextDevicePlacementPolicy policy;
56+
57+
// Note: we cannot use C++11 thread_local here as there is no concept of a
58+
// thread-local-object-local variable in C++11.
59+
tensorflow::mutex policy_map_mu;
60+
std::unordered_map<std::thread::id, TFE_ContextDevicePlacementPolicy>
61+
thread_local_policies GUARDED_BY(policy_map_mu);
4762

4863
// TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
49-
TF_Session* session;
50-
tensorflow::Rendezvous* rendezvous;
64+
TF_Session* const session;
65+
tensorflow::Rendezvous* const rendezvous;
5166

5267
tensorflow::mutex functions_mu;
5368
tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){
@@ -56,14 +71,14 @@ struct TFE_Context {
5671
// One FunctionLibraryRuntime per device.
5772
// func_libs[i] is the FunctionLibraryRuntime corresponding to
5873
// session->devices[i].
59-
std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
74+
const std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
6075

6176
tensorflow::mutex cache_mu;
6277
std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
6378
tensorflow::Fprint128Hasher>
6479
kernel_cache GUARDED_BY(cache_mu);
6580

66-
tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
81+
tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const {
6782
return pflr->GetFLR(d->name());
6883
}
6984

0 commit comments

Comments
 (0)