Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Pluggable Devices #8524

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tfjs-node/binding/tfjs_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/core/platform/ctstring_internal.h"
#include "tf_auto_tensor.h"
#include "tfe_auto_op.h"
Expand Down Expand Up @@ -688,6 +689,10 @@ TFJSBackend::TFJSBackend(napi_env env)
device_name =
std::string(TF_DeviceListName(device_list, i, tf_status.status));
ENSURE_TF_OK(env, tf_status);
} else if (strcmp(device_type, "XPU") == 0) {
device_name =
std::string(TF_DeviceListName(device_list, i, tf_status.status));
ENSURE_TF_OK(env, tf_status);
}
}

Expand Down Expand Up @@ -717,6 +722,17 @@ TFJSBackend::~TFJSBackend() {

TFJSBackend *TFJSBackend::Create(napi_env env) { return new TFJSBackend(env); }

void TFJSBackend::LoadPluggableDeviceLibrary(napi_env env, napi_value lib_path_value) {
std::string lib_path;
napi_status nstatus = GetStringParam(env, lib_path_value, lib_path);
ENSURE_NAPI_OK(env, nstatus);

TF_AutoStatus tf_status;
TF_LoadPluggableDeviceLibrary(lib_path.c_str(), tf_status.status);

ENSURE_TF_OK(env, tf_status);
}

int32_t TFJSBackend::InsertHandle(TFE_TensorHandle *tfe_handle) {
return tfe_handle_map_.insert(std::make_pair(next_tensor_id_++, tfe_handle))
.first->first;
Expand Down
2 changes: 2 additions & 0 deletions tfjs-node/binding/tfjs_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class TFJSBackend {
// fails, a nullptr is returned.
static TFJSBackend *Create(napi_env env);

void LoadPluggableDeviceLibrary(napi_env env, napi_value lib_path_value);

// Creates a new Tensor with given shape and data and returns an ID that
// references the new Tensor.
// - shape_value (number[])
Expand Down
29 changes: 29 additions & 0 deletions tfjs-node/binding/tfjs_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,33 @@ static void AssignIntProperty(napi_env env, napi_value exports,
ENSURE_NAPI_OK(env, nstatus);
}

static napi_value LoadPluggableDeviceLibrary(napi_env env, napi_callback_info info) {
napi_status nstatus;

// Delete tensor takes 1 param: tensor ID;
size_t argc = 1;
napi_value args[1];
napi_value js_this;
nstatus = napi_get_cb_info(env, info, &argc, args, &js_this, nullptr);
ENSURE_NAPI_OK_RETVAL(env, nstatus, js_this);

if (argc < 1) {
NAPI_THROW_ERROR(env,
"Invalid number of args passed to loadPluggableDeviceLibrary(). "
"Expecting 1 arg but got %d.",
argc);
return js_this;
}

ENSURE_VALUE_IS_STRING_RETVAL(env, args[0], js_this);

TFJSBackend *const backend = GetTFJSBackend(env);
if (!backend) return nullptr;

backend->LoadPluggableDeviceLibrary(env, args[0]);
return js_this;
}

static napi_value CreateTensor(napi_env env, napi_callback_info info) {
napi_status nstatus;

Expand Down Expand Up @@ -301,6 +328,8 @@ static napi_value InitTFNodeJSBinding(napi_env env, napi_value exports) {

// Set all export values list here.
napi_property_descriptor exports_properties[] = {
{"loadPluggableDeviceLibrary", nullptr, LoadPluggableDeviceLibrary, nullptr, nullptr, nullptr,
napi_default, nullptr},
{"createTensor", nullptr, CreateTensor, nullptr, nullptr, nullptr,
napi_default, nullptr},
{"deleteTensor", nullptr, DeleteTensor, nullptr, nullptr, nullptr,
Expand Down
8 changes: 6 additions & 2 deletions tfjs-node/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ if (!fs.existsSync(bindingPath)) {
`WINDOWS_TROUBLESHOOTING.md or file an issue.`);
}
// tslint:disable-next-line:no-require-imports
const bindings = require(bindingPath);
const bindings = require(bindingPath) as TFJSBinding;

// Merge version and io namespaces.
export const version = {
Expand All @@ -70,7 +70,7 @@ const pjson = require('../package.json');

// Side effects for default initialization of Node backend.
tf.registerBackend('tensorflow', () => {
return new NodeJSKernelBackend(bindings as TFJSBinding, pjson.name);
return new NodeJSKernelBackend(bindings, pjson.name);
}, 3 /* priority */);

const success = tf.setBackend('tensorflow');
Expand All @@ -84,3 +84,7 @@ tf.io.registerSaveRouter(nodeFileSystemRouter);

// Register the ProgbarLogger for Model.fit() at verbosity level 1.
tf.registerCallbackConstructor(1, ProgbarLogger);

export function loadPluggableDeviceLibrary(libPath: string) {
bindings.loadPluggableDeviceLibrary(libPath);
}
2 changes: 2 additions & 0 deletions tfjs-node/src/tfjs_binding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ export interface TFJSBinding {
TensorMetadata: typeof TensorMetadata;
TFEOpAttr: typeof TFEOpAttr;

loadPluggableDeviceLibrary(libPath: string): void;

// Creates a tensor with the backend.
createTensor(
shape: number[], dtype: number,
Expand Down