diff --git a/tfjs-node/binding/tfjs_backend.cc b/tfjs-node/binding/tfjs_backend.cc index 4267cf5b96..1811e6adfb 100644 --- a/tfjs-node/binding/tfjs_backend.cc +++ b/tfjs-node/binding/tfjs_backend.cc @@ -27,6 +27,7 @@ #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_tensor.h" +#include "tensorflow/c/c_api_experimental.h" #include "tensorflow/core/platform/ctstring_internal.h" #include "tf_auto_tensor.h" #include "tfe_auto_op.h" @@ -688,6 +689,10 @@ TFJSBackend::TFJSBackend(napi_env env) device_name = std::string(TF_DeviceListName(device_list, i, tf_status.status)); ENSURE_TF_OK(env, tf_status); + } else if (strcmp(device_type, "XPU") == 0) { + device_name = + std::string(TF_DeviceListName(device_list, i, tf_status.status)); + ENSURE_TF_OK(env, tf_status); } } @@ -717,6 +722,17 @@ TFJSBackend::~TFJSBackend() { TFJSBackend *TFJSBackend::Create(napi_env env) { return new TFJSBackend(env); } +void TFJSBackend::LoadPluggableDeviceLibrary(napi_env env, napi_value lib_path_value) { + std::string lib_path; + napi_status nstatus = GetStringParam(env, lib_path_value, lib_path); + ENSURE_NAPI_OK(env, nstatus); + + TF_AutoStatus tf_status; + TF_LoadPluggableDeviceLibrary(lib_path.c_str(), tf_status.status); + + ENSURE_TF_OK(env, tf_status); +} + int32_t TFJSBackend::InsertHandle(TFE_TensorHandle *tfe_handle) { return tfe_handle_map_.insert(std::make_pair(next_tensor_id_++, tfe_handle)) .first->first; diff --git a/tfjs-node/binding/tfjs_backend.h b/tfjs-node/binding/tfjs_backend.h index fee71d0607..28dcbecdd1 100644 --- a/tfjs-node/binding/tfjs_backend.h +++ b/tfjs-node/binding/tfjs_backend.h @@ -35,6 +35,8 @@ class TFJSBackend { // fails, a nullptr is returned. static TFJSBackend *Create(napi_env env); + void LoadPluggableDeviceLibrary(napi_env env, napi_value lib_path_value); + // Creates a new Tensor with given shape and data and returns an ID that // references the new Tensor. // - shape_value (number[]) diff --git a/tfjs-node/binding/tfjs_binding.cc b/tfjs-node/binding/tfjs_binding.cc index ad00b79e78..b551c7d543 100644 --- a/tfjs-node/binding/tfjs_binding.cc +++ b/tfjs-node/binding/tfjs_binding.cc @@ -48,6 +48,33 @@ static void AssignIntProperty(napi_env env, napi_value exports, ENSURE_NAPI_OK(env, nstatus); } +static napi_value LoadPluggableDeviceLibrary(napi_env env, napi_callback_info info) { + napi_status nstatus; + + // Delete tensor takes 1 param: tensor ID; + size_t argc = 1; + napi_value args[1]; + napi_value js_this; + nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr); + ENSURE_NAPI_OK_RETVAL(env, nstatus, js_this); + + if (argc < 1) { + NAPI_THROW_ERROR(env, + "Invalid number of args passed to loadPluggableDeviceLibrary(). " + "Expecting 1 arg but got %d.", + argc); + return js_this; + } + + ENSURE_VALUE_IS_STRING_RETVAL(env, args[0], js_this); + + TFJSBackend *const backend = GetTFJSBackend(env); + if (!backend) return nullptr; + + backend->LoadPluggableDeviceLibrary(env, args[0]); + return js_this; +} + static napi_value CreateTensor(napi_env env, napi_callback_info info) { napi_status nstatus; @@ -301,6 +328,8 @@ static napi_value InitTFNodeJSBinding(napi_env env, napi_value exports) { // Set all export values list here. napi_property_descriptor exports_properties[] = { + {"loadPluggableDeviceLibrary", nullptr, LoadPluggableDeviceLibrary, nullptr, nullptr, nullptr, + napi_default, nullptr}, {"createTensor", nullptr, CreateTensor, nullptr, nullptr, nullptr, napi_default, nullptr}, {"deleteTensor", nullptr, DeleteTensor, nullptr, nullptr, nullptr, diff --git a/tfjs-node/src/index.ts b/tfjs-node/src/index.ts index c81bddd7b7..4910a39a7f 100644 --- a/tfjs-node/src/index.ts +++ b/tfjs-node/src/index.ts @@ -49,7 +49,7 @@ if (!fs.existsSync(bindingPath)) { `WINDOWS_TROUBLESHOOTING.md or file an issue.`); } // tslint:disable-next-line:no-require-imports -const bindings = require(bindingPath); +const bindings = require(bindingPath) as TFJSBinding; // Merge version and io namespaces. export const version = { @@ -70,7 +70,7 @@ const pjson = require('../package.json'); // Side effects for default initialization of Node backend. tf.registerBackend('tensorflow', () => { - return new NodeJSKernelBackend(bindings as TFJSBinding, pjson.name); + return new NodeJSKernelBackend(bindings, pjson.name); }, 3 /* priority */); const success = tf.setBackend('tensorflow'); @@ -84,3 +84,7 @@ tf.io.registerSaveRouter(nodeFileSystemRouter); // Register the ProgbarLogger for Model.fit() at verbosity level 1. tf.registerCallbackConstructor(1, ProgbarLogger); + +export function loadPluggableDeviceLibrary(libPath: string) { + bindings.loadPluggableDeviceLibrary(libPath); +} diff --git a/tfjs-node/src/tfjs_binding.ts b/tfjs-node/src/tfjs_binding.ts index 3352457b75..19eacb32c5 100644 --- a/tfjs-node/src/tfjs_binding.ts +++ b/tfjs-node/src/tfjs_binding.ts @@ -33,6 +33,8 @@ export interface TFJSBinding { TensorMetadata: typeof TensorMetadata; TFEOpAttr: typeof TFEOpAttr; + loadPluggableDeviceLibrary(libPath: string): void; + // Creates a tensor with the backend. createTensor( shape: number[], dtype: number,