@@ -21,6 +21,7 @@ limitations under the License.
 #include <cstddef>
 #include <memory>
 #include <string>
+#include <thread>
 #include <vector>
 
 #include "tensorflow/c/c_api.h"
@@ -34,20 +35,34 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/stl_util.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/version.h"
 
 struct TFE_ContextOptions {
   TF_SessionOptions session_options;
-  TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT};
+  TFE_ContextDevicePlacementPolicy policy{
+      TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32};
 };
 
 struct TFE_Context {
-  explicit TFE_Context(TF_Session* s) : session(s) {}
-
-  TFE_ContextDevicePlacementPolicy policy;
+  explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s)
+      : policy(opts.policy),
+        session(s),
+        rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)),
+        pflr(new tensorflow::ProcessFunctionLibraryRuntime(
+            session->device_mgr, opts.session_options.options.env,
+            TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {}
+
+  const TFE_ContextDevicePlacementPolicy policy;
+
+  // Note: we cannot use C++11 thread_local here as there is no concept of a
+  // thread-local-object-local variable in C++11.
+  tensorflow::mutex policy_map_mu;
+  std::unordered_map<std::thread::id, TFE_ContextDevicePlacementPolicy>
+      thread_local_policies GUARDED_BY(policy_map_mu);
 
   // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
-  TF_Session* session;
-  tensorflow::Rendezvous* rendezvous;
+  TF_Session* const session;
+  tensorflow::Rendezvous* const rendezvous;
 
   tensorflow::mutex functions_mu;
   tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){
@@ -56,14 +71,14 @@ struct TFE_Context {
   // One FunctionLibraryRuntime per device.
   // func_libs[i] is the FunctionLibraryRuntime corresponding to
   // session->devices[i].
-  std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
+  const std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
 
   tensorflow::mutex cache_mu;
   std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
                      tensorflow::Fprint128Hasher>
       kernel_cache GUARDED_BY(cache_mu);
 
-  tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
+  tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const {
     return pflr->GetFLR(d->name());
   }
 
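A note on the thread_local_policies field added above: C++11 thread_local applies only to variables with static or thread storage duration, so a per-context, per-thread setting cannot be a plain thread_local data member. The diff instead keys a map by std::thread::id and guards it with a mutex. The standalone sketch below illustrates that pattern under assumed names (PolicyHolder, SetThreadLocalPolicy, and Policy are illustrative, not TensorFlow APIs):

// Illustrative sketch only; PolicyHolder and its methods are hypothetical
// and not part of the TensorFlow sources touched by this commit.
#include <mutex>
#include <thread>
#include <unordered_map>

enum class Policy { kExplicit, kSilentForInt32 };

class PolicyHolder {
 public:
  explicit PolicyHolder(Policy default_policy)
      : default_policy_(default_policy) {}

  // Records an override that applies only to the calling thread.
  void SetThreadLocalPolicy(Policy p) {
    std::lock_guard<std::mutex> lock(mu_);
    overrides_[std::this_thread::get_id()] = p;
  }

  // Returns the calling thread's override if present, else the default.
  Policy GetPolicy() const {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = overrides_.find(std::this_thread::get_id());
    return it == overrides_.end() ? default_policy_ : it->second;
  }

 private:
  const Policy default_policy_;
  mutable std::mutex mu_;
  // Keyed by thread id because a non-static data member cannot be declared
  // thread_local in C++11; there is no per-object, per-thread storage.
  std::unordered_map<std::thread::id, Policy> overrides_;
};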