Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IFRT] add c-bindings for "Held" PjRt classes #751

Merged
merged 9 commits into from
Feb 19, 2025
Merged
328 changes: 286 additions & 42 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
@@ -91,6 +91,7 @@
#include "xla/python/ifrt/topology.h"
#include "xla/python/ifrt/tuple.h"
#include "xla/python/ifrt/value.h"
#include "xla/python/ifrt/ir/ifrt_ir_program.h"

// IFRT - PJRT
#include "xla/python/pjrt_ifrt/pjrt_array.h"
@@ -1285,41 +1286,298 @@ template <typename T> HeldValue<T> *capture(T obj) {
} // namespace reactant

// Handle types passed across the C API boundary. Each alias is a
// `reactant::HeldValue` instantiated with a smart pointer to the underlying
// XLA object.
using reactant::HeldValue;
// Holder around a `std::shared_ptr<xla::PjRtClient>`.
using HeldPjRtClient = HeldValue<std::shared_ptr<xla::PjRtClient>>;
// Holder around a `std::shared_ptr<xla::PjRtBuffer>`.
using HeldPjRtBuffer = HeldValue<std::shared_ptr<xla::PjRtBuffer>>;
// Holder around a `tsl::RCReference<xla::ifrt::Array>`.
using HeldIfrtArray = HeldValue<tsl::RCReference<xla::ifrt::Array>>;

// deprecated
// extern "C" HeldPjRtClient * reactant_hold_pjrtclient(xla::PjRtClient *client) {
// return reactant::capture(std::shared_ptr<PjRtClient>(client));
// }

// Creates a CPU PjRt client via `MakeCPUClient` (all arguments are forwarded
// unchanged) and returns it inside a `HeldPjRtClient`.
//
// The raw pointer returned by `MakeCPUClient` is handed to a
// `std::shared_ptr`, which takes ownership of it. The returned holder is
// released with `pjrt_client_dtor`.
extern "C" HeldPjRtClient *pjrt_make_cpu_client_shared(uint8_t asynchronous,
                                                       int node_id,
                                                       int num_nodes) {
  PjRtClient *const raw_client =
      MakeCPUClient(asynchronous, node_id, num_nodes);
  std::shared_ptr<PjRtClient> owned_client(raw_client);
  return reactant::capture(owned_client);
}

extern "C" HeldValue<std::shared_ptr<PjRtClient>> *
reactant_hold_pjrtclient(xla::PjRtClient *client) {
// Creates a GPU PjRt client via `MakeGPUClient` (all arguments, including the
// `error` out-parameter, are forwarded unchanged) and returns it inside a
// `HeldPjRtClient`.
//
// The raw pointer returned by `MakeGPUClient` is handed to a
// `std::shared_ptr`, which takes ownership of it.
// NOTE(review): what `MakeGPUClient` returns on failure is not visible here;
// callers should presumably inspect `*error` before using the handle.
extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared(
    int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
    double memory_fraction, bool preallocate, const char *platform_name,
    const char **error) {
  PjRtClient *const raw_client =
      MakeGPUClient(node_id, num_nodes, allowed_devices, num_allowed_devices,
                    memory_fraction, preallocate, platform_name, error);
  std::shared_ptr<PjRtClient> owned_client(raw_client);
  return reactant::capture(owned_client);
}

extern "C" void
reactant_release_pjrtclient(HeldValue<std::shared_ptr<PjRtClient>> *client) {
// Creates a TPU PjRt client via `MakeTPUClient(tpu_path, error)` and returns
// it inside a `HeldPjRtClient`.
//
// The raw pointer returned by `MakeTPUClient` is handed to a
// `std::shared_ptr`, which takes ownership of it.
// NOTE(review): what `MakeTPUClient` returns on failure is not visible here;
// callers should presumably inspect `*error` before using the handle.
extern "C" HeldPjRtClient *pjrt_make_tpu_client_shared(const char *tpu_path,
                                                       const char **error) {
  PjRtClient *const raw_client = MakeTPUClient(tpu_path, error);
  std::shared_ptr<PjRtClient> owned_client(raw_client);
  return reactant::capture(owned_client);
}

// Destroys a `HeldPjRtClient` holder with `delete`.
// Presumably this drops the holder's reference to the wrapped client; the
// handle must not be used afterwards.
extern "C" void pjrt_client_dtor(HeldPjRtClient *client) { delete client; }

extern "C" HeldValue<std::shared_ptr<xla::PjRtBuffer>> *
reactant_hold_pjrtbuffer(xla::PjRtBuffer *buffer) {
return reactant::capture(std::shared_ptr<xla::PjRtBuffer>(buffer));
// Returns `device_count()` of the PjRt client wrapped by `client`.
extern "C" int pjrt_client_num_devices(HeldPjRtClient *client) {
  auto pjrt = client->ptr();
  return pjrt->device_count();
}

extern "C" void
reactant_release_pjrtbuffer(HeldValue<std::shared_ptr<PjRtBuffer>> *buffer) {
// Returns `addressable_device_count()` of the PjRt client wrapped by
// `client`.
extern "C" int pjrt_client_num_addressable_devices(HeldPjRtClient *client) {
  auto pjrt = client->ptr();
  return pjrt->addressable_device_count();
}

// Returns `process_index()` of the PjRt client wrapped by `client`.
extern "C" int pjrt_client_pid(HeldPjRtClient *client) {
  auto pjrt = client->ptr();
  return pjrt->process_index();
}

// Looks up a device on the wrapped PjRt client by delegating to
// `ClientGetDevice` with the unwrapped client pointer and `device_id`.
extern "C" PjRtDevice *pjrt_client_get_device(HeldPjRtClient *client,
                                              int device_id) {
  auto pjrt = client->ptr();
  return ClientGetDevice(pjrt, device_id);
}

// Looks up an addressable device on the wrapped PjRt client by delegating to
// `ClientGetAddressableDevice` with the unwrapped client pointer and
// `device_id`.
extern "C" PjRtDevice *pjrt_client_get_addressable_device(
    HeldPjRtClient *client, int device_id) {
  auto pjrt = client->ptr();
  return ClientGetAddressableDevice(pjrt, device_id);
}

// Returns the platform name of the wrapped PjRt client as produced by
// `ClientGetPlatformName`.
// NOTE(review): the lifetime/ownership of the returned string is decided by
// `ClientGetPlatformName`, which is not visible here.
extern "C" const char *pjrt_client_platform_name(HeldPjRtClient *client) {
  auto pjrt = client->ptr();
  return ClientGetPlatformName(pjrt);
}

// deprecated
// extern "C" HeldValue<std::shared_ptr<xla::PjRtBuffer>> *
// reactant_hold_pjrtbuffer(xla::PjRtBuffer *buffer) {
// return reactant::capture(std::shared_ptr<xla::PjRtBuffer>(buffer));
// }

// Creates a PjRt buffer from host memory by delegating to
// `ArrayFromHostBuffer` with the unwrapped client and the remaining arguments
// forwarded unchanged, and returns it inside a `HeldPjRtBuffer`.
//
// The raw pointer returned by `ArrayFromHostBuffer` is handed to a
// `std::shared_ptr`, which takes ownership of it. The returned holder is
// released with `pjrt_buffer_dtor`.
extern "C" HeldPjRtBuffer *pjrt_buffer_from_host(HeldPjRtClient *client,
                                                 void *data, uint64_t ptype,
                                                 size_t dim, int64_t *cshape,
                                                 PjRtDevice *device) {
  PjRtBuffer *const raw_buffer =
      ArrayFromHostBuffer(client->ptr(), data, ptype, dim, cshape, device);
  std::shared_ptr<PjRtBuffer> owned_buffer(raw_buffer);
  return reactant::capture(owned_buffer);
}

// Destroys a `HeldPjRtBuffer` holder with `delete`.
// Presumably this drops the holder's reference to the wrapped buffer; the
// handle must not be used afterwards.
extern "C" void pjrt_buffer_dtor(HeldPjRtBuffer *buffer) { delete buffer; }

extern "C" ifrt::Client *
ifrt_pjrt_MakeClient(HeldValue<std::shared_ptr<PjRtClient>> *pjrt_client) {
// Returns whatever `UnsafeBufferPointer` yields for the wrapped PjRt buffer.
// NOTE(review): validity and lifetime of the returned pointer are governed
// by `UnsafeBufferPointer`, which is not visible here.
extern "C" void *pjrt_buffer_unsafe_buffer_pointer(HeldPjRtBuffer *buffer) {
  auto pjrt_buffer = buffer->ptr();
  return UnsafeBufferPointer(pjrt_buffer);
}

// Returns `IsOnCpu()` of the wrapped PjRt buffer.
extern "C" bool pjrt_buffer_is_on_cpu(HeldPjRtBuffer *buffer) {
  auto pjrt_buffer = buffer->ptr();
  return pjrt_buffer->IsOnCpu();
}

// Copies the wrapped buffer to `dst_device` by delegating to
// `CopyBufferToDevice`, and returns the result inside a new
// `HeldPjRtBuffer`.
//
// The raw pointer returned by `CopyBufferToDevice` is handed to a
// `std::shared_ptr`, which takes ownership of it. The input handle is left
// untouched.
extern "C" HeldPjRtBuffer *pjrt_buffer_copy_to_device(HeldPjRtBuffer *buffer,
                                                      PjRtDevice *dst_device) {
  PjRtBuffer *const raw_copy = CopyBufferToDevice(buffer->ptr(), dst_device);
  std::shared_ptr<PjRtBuffer> owned_copy(raw_copy);
  return reactant::capture(owned_copy);
}

// Transfers the wrapped buffer to host memory at `data` by delegating to
// `BufferToHost`.
// NOTE(review): `data` presumably must be large enough for the whole buffer;
// the size contract lives in `BufferToHost`, which is not visible here.
extern "C" void pjrt_buffer_to_host(HeldPjRtBuffer *buffer, void *data) {
  auto pjrt_buffer = buffer->ptr();
  BufferToHost(pjrt_buffer, data);
}

// Prints the wrapped buffer by delegating to `PrintPjRtBuffer`.
extern "C" void pjrt_buffer_print(HeldPjRtBuffer *buffer) {
  auto pjrt_buffer = buffer->ptr();
  PrintPjRtBuffer(pjrt_buffer);
}

// Returns `device()` of the wrapped PjRt buffer.
extern "C" PjRtDevice *pjrt_buffer_get_device(HeldPjRtBuffer *buffer) {
  auto pjrt_buffer = buffer->ptr();
  return pjrt_buffer->device();
}

// Returns a `HeldPjRtClient` handle referring to the client that the wrapped
// buffer reports through `client()`.
//
// `client()` yields a raw, borrowed pointer: the client is already owned
// elsewhere (e.g. by the `HeldPjRtClient` the buffer was created through).
// Wrapping that pointer in a plain owning `std::shared_ptr` would create a
// second, unrelated owner and the client would be deleted twice (or while
// still in use) once this handle is destroyed. The handle returned here is
// therefore NON-OWNING: its `shared_ptr` carries a no-op deleter.
//
// The holder itself must still be released with `pjrt_client_dtor`, and it
// must not be used after the real owner of the client has been destroyed.
extern "C" HeldPjRtClient *pjrt_buffer_get_client(HeldPjRtBuffer *buffer) {
  PjRtClient *borrowed_client = buffer->ptr()->client();
  std::shared_ptr<PjRtClient> non_owning(borrowed_client,
                                         [](PjRtClient *) {});
  return reactant::capture(non_owning);
}

// Creates an IFRT client backed by the given PjRt client.
//
// The held `std::shared_ptr<PjRtClient>` is placed into
// `xla::ifrt::PjRtClient::CreateOptions`, `xla::ifrt::PjRtClient::Create` is
// invoked, and its result is unwrapped with `MyValueOrThrow`. Ownership of
// the created client is released to the caller as a raw pointer, to be freed
// with `ifrt_client_dtor`.
extern "C" ifrt::Client *ifrt_pjrt_make_client(HeldPjRtClient *pjrt_client) {
  xla::ifrt::PjRtClient::CreateOptions create_opts = {pjrt_client->obj()};
  auto created = MyValueOrThrow(xla::ifrt::PjRtClient::Create(create_opts));
  return created.release();
}

// Destroys an IFRT client with `delete`.
extern "C" void ifrt_FreeClient(ifrt::Client *client) {
  delete client;
}
// Destroys an IFRT client with `delete`.
extern "C" void ifrt_client_dtor(ifrt::Client *client) {
  delete client;
}

// Generic version. Note that the IFRT-PjRt backend only supports
// SingleDeviceSharding and FullyReplicated; use `ifrt_pjrt_array_create` when
// running on IFRT-PjRt.
//
// Builds an IFRT array from a host buffer through
// `Client::MakeArrayFromHostBuffer` and returns it inside a `HeldIfrtArray`.
//   - `dtype_kind`  : integer value cast to `ifrt::DType::Kind`.
//   - `ndims`/`c_shape` : `c_shape` is read as `ndims` int64 extents.
//   - `sharding`    : held sharding; its `obj()` is passed to the client.
//   - `c_semantics` : integer value cast to
//                     `ifrt::Client::HostBufferSemantics`.
// No byte strides are given (`std::nullopt`) and the "done with host buffer"
// callback is a no-op. Failures are surfaced through `MyValueOrThrow`.
extern "C" HeldIfrtArray *ifrt_client_make_array_from_host_buffer(
    ifrt::Client *client, void *data,
    int dtype_kind, // int
    int ndims, const int64_t *c_shape,
    HeldValue<std::shared_ptr<const ifrt::Sharding>> *sharding,
    int c_semantics) {
  const auto kind = static_cast<ifrt::DType::Kind>(dtype_kind);
  auto dtype = ifrt::DType(kind);

  const absl::Span<const int64_t> dims(c_shape, ndims);
  auto shape = ifrt::Shape(dims);

  const auto semantics =
      static_cast<ifrt::Client::HostBufferSemantics>(c_semantics);

  return reactant::capture(MyValueOrThrow(client->MakeArrayFromHostBuffer(
      data, dtype, shape,
      std::nullopt, // byte_strides
      sharding->obj(), semantics,
      [] {} // on_done_with_host_buffer
      )));
}

// Builds a single-device IFRT array from a host buffer.
//
// A `SingleDeviceSharding` is created for `device` with the memory kind named
// by `mem_kind`, and the work is delegated to
// `ifrt_client_make_array_from_host_buffer`; see that function for the
// meaning of `dtype_kind`, `ndims`, `c_shape` and `c_semantics`.
//
// `mem_kind` must be a valid NUL-terminated string (it is copied into a
// `std::string`).
//
// The sharding holder produced by `reactant::capture` is only needed for the
// duration of the delegated call, so it is owned by a `std::unique_ptr` here.
// Previously it was never deleted and leaked on every call. The delegated
// function passes the result of `sharding->obj()` to the client rather than
// the holder itself, so destroying the holder afterwards is safe, also when
// the call throws.
extern "C" HeldIfrtArray *ifrt_client_make_single_shard_array_from_host_buffer(
    ifrt::Client *client, void *data,
    int dtype_kind, // int
    int ndims, const int64_t *c_shape, int c_semantics, ifrt::Device *device,
    const char *mem_kind) {
  auto memory_kind = ifrt::MemoryKind(std::string(mem_kind));

  using HeldSharding = HeldValue<std::shared_ptr<const ifrt::Sharding>>;
  const std::unique_ptr<HeldSharding> sharding(
      reactant::capture(std::shared_ptr<const ifrt::Sharding>(
          ifrt::SingleDeviceSharding::Create(device, memory_kind).release())));

  return ifrt_client_make_array_from_host_buffer(
      client, data, dtype_kind, ndims, c_shape, sharding.get(), c_semantics);
}

// All arrays are assumed to have the same DType.
//
// Assembles one IFRT array out of `narrays` held arrays through
// `Client::AssembleArrayFromSingleDeviceArrays` and returns it inside a
// `HeldIfrtArray`.
//   - `ndims`/`c_shape` : `c_shape` is read as `ndims` int64 extents of the
//                         assembled array.
//   - `sharding`    : held sharding; its `obj()` is passed to the client.
//   - `c_arrays`    : `narrays` handles whose `obj()` values are collected
//                     and passed as the shards.
//   - `c_semantics` : integer value cast to `ifrt::ArrayCopySemantics`.
// Failures are surfaced through `MyValueOrThrow`.
extern "C" HeldIfrtArray *ifrt_client_assemble_array_from_single_shards(
    ifrt::Client *client, int ndims, const int64_t *c_shape,
    HeldValue<std::shared_ptr<const ifrt::Sharding>> *sharding, int narrays,
    HeldIfrtArray **c_arrays, int c_semantics) {
  const absl::Span<const int64_t> dims(c_shape, ndims);
  auto shape = ifrt::Shape(dims);

  // Collect the held array references into a contiguous vector.
  std::vector<tsl::RCReference<ifrt::Array>> shards;
  for (int idx = 0; idx < narrays; ++idx) {
    shards.push_back(c_arrays[idx]->obj());
  }
  absl::Span<tsl::RCReference<xla::ifrt::Array>> shard_span(shards);

  auto copy_semantics = static_cast<ifrt::ArrayCopySemantics>(c_semantics);
  return reactant::capture(
      MyValueOrThrow(client->AssembleArrayFromSingleDeviceArrays(
          shape, sharding->obj(), shard_span, copy_semantics)));
}

// We should deprecate this because it is IFRT-PjRt specific; try to use
// `ifrt_client_make_single_shard_array_from_host_buffer` instead.
//
// Wraps an existing PjRt buffer as an IFRT array: `PjRtArray::Create` is
// called with `client` and the held buffer's `obj()`, the result is unwrapped
// with `MyValueOrThrow`, converted to a `tsl::RCReference<ifrt::Array>` and
// returned inside a `HeldIfrtArray`.
extern "C" HeldIfrtArray *
ifrt_pjrt_array_create(ifrt::PjRtClient *client,
                       HeldValue<std::shared_ptr<xla::PjRtBuffer>> *buffer) {
  auto pjrt_array =
      MyValueOrThrow(xla::ifrt::PjRtArray::Create(client, buffer->obj()));
  return reactant::capture(
      tsl::RCReference<ifrt::Array>(std::move(pjrt_array)));
}

extern "C" xla::ifrt::LoadedExecutable *
ifrt_ClientCompile(ifrt::PjRtClient *client, MlirModule cmod, int64_t device_id,
bool is_sharded, const int64_t *mesh_ids,
int64_t num_mesh_ids, const char *xla_gpu_cuda_data_dir) {
CompileOptions options = GenerateCompileOptions(
device_id, is_sharded, mesh_ids, num_mesh_ids, xla_gpu_cuda_data_dir);
// TODO how do we compile for other backends?
// extern "C" xla::ifrt::LoadedExecutable* ifrt_pjrt_compile(
// ifrt::PjRtClient *client, MlirModule cmod, int64_t device_id,
// bool is_sharded, const int64_t *mesh_ids,
// int64_t num_mesh_ids, const char *xla_gpu_cuda_data_dir
// ) {
// CompileOptions options = GenerateCompileOptions(
// device_id, is_sharded, mesh_ids, num_mesh_ids, xla_gpu_cuda_data_dir);

// mlir::ModuleOp cmod_op = cast<ModuleOp>(*unwrap(cmod));
// if (is_sharded) {
// // https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460
// auto status = xla::ExportShardyForHloRoundTrip(cmod_op);
// if (!status.ok()) {
// ReactantThrowError(status.ToString().c_str());
// }
// }

// // TODO can't create LoadedExecutable from mlir::ModuleOp on IFRT-proxy
// // backend
// auto exec = MyValueOrThrow(xla::ifrt::PjRtLoadedExecutable::Create(
// client, cmod_op, options,
// std::vector<tsl::RCReference<xla::ifrt::LoadedHostCallback>>()));
// return exec.release();
// }

// we might be interested in the `Compiler::Compile` method variant that
// accepts a `Topology`
extern "C" xla::ifrt::LoadedExecutable* ifrt_compile(
ifrt::Client *client, MlirModule cmod, int64_t device_id,
bool is_sharded, const int64_t *mesh_ids,
int64_t num_mesh_ids, const char *xla_gpu_cuda_data_dir
) {
// TODO we need a `xla::ifrt::CompileOptions` but this is `xla::CompileOptions`
auto options = std::make_unique<CompileOptions>(
GenerateCompileOptions(
device_id,
is_sharded,
mesh_ids,
num_mesh_ids,
xla_gpu_cuda_data_dir
)
);
Copy link
Collaborator Author

@mofeing mofeing Feb 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is crashing because `compiler->Compile` expects an `ifrt::Compiler::CompileOptions`, not an `xla::CompileOptions`.

@avik-pal would you mind taking care of this? i think we need to use this https://github.com/openxla/xla/blob/81335eabdc55f0a2ff02bde7f79da7adda7af2c9/xla/python/ifrt/ir/ifrt_ir_program.h#L102-L109


mlir::ModuleOp cmod_op = cast<ModuleOp>(*unwrap(cmod));
if (is_sharded) {
@@ -1330,40 +1588,26 @@ ifrt_ClientCompile(ifrt::PjRtClient *client, MlirModule cmod, int64_t device_id,
}
}

// TODO can't create LoadedExecutable from mlir::ModuleOp on IFRT-proxy
// backend
auto exec = MyValueOrThrow(xla::ifrt::PjRtLoadedExecutable::Create(
client, cmod_op, options,
std::vector<tsl::RCReference<xla::ifrt::LoadedHostCallback>>()));
return exec.release();
auto program = std::make_unique<xla::ifrt::Program>(xla::ifrt::IfrtIRProgram(cmod_op));
auto compiler = client->GetDefaultCompiler();

return MyValueOrThrow(compiler->Compile(program, options)).release();
}

extern "C" void
ifrt_pjrt_FreeLoadedExecutable(xla::ifrt::PjRtLoadedExecutable *exec) {
ifrt_pjrt_loaded_executable_dtor(xla::ifrt::PjRtLoadedExecutable *exec) {
delete exec;
}

// TODO replace with `Client::MakeArrayFromHostBuffer` and generalize to
// `ifrt::Client`
//
// Wraps an existing PjRt buffer as an IFRT array: `PjRtArray::Create` is
// called with `client` and the held buffer's `obj()`, the result is unwrapped
// with `MyValueOrThrow`, converted to a `tsl::RCReference<ifrt::Array>` and
// returned inside a `HeldValue`.
extern "C" HeldValue<tsl::RCReference<xla::ifrt::Array>> *
ifrt_pjrt_ArrayFromHostBuffer(
    ifrt::PjRtClient *client,
    HeldValue<std::shared_ptr<xla::PjRtBuffer>> *buffer) {
  auto pjrt_array =
      MyValueOrThrow(xla::ifrt::PjRtArray::Create(client, buffer->obj()));
  return reactant::capture(
      tsl::RCReference<ifrt::Array>(std::move(pjrt_array)));
}

// Destroys a held IFRT array handle with `delete`.
// Presumably this drops the holder's reference to the wrapped array; the
// handle must not be used afterwards.
extern "C" void reactant_release_ifrt_array(
    HeldValue<tsl::RCReference<xla::ifrt::Array>> *array) { delete array; }
// Destroys a `HeldIfrtArray` holder with `delete`.
// Presumably this drops the holder's reference to the wrapped array; the
// handle must not be used afterwards.
extern "C" void ifrt_array_dtor(HeldIfrtArray *array) {
  delete array;
}

extern "C" void
ifrt_Execute(ifrt::LoadedExecutable *exec, int num_args,
ifrt_loaded_executable_execute(ifrt::LoadedExecutable *exec, int num_args,
HeldValue<tsl::RCReference<ifrt::Array>> **op_args,
uint8_t *is_arg_donatable, int num_results,
HeldValue<tsl::RCReference<ifrt::Array>> **op_results,
uint8_t *futures, FutureType **status) {
uint8_t *futures, FutureType **status)
{
std::vector<tsl::RCReference<xla::ifrt::Array>> args;
for (int i = 0; i < num_args; i++) {
args.emplace_back(op_args[i]->obj());
@@ -1399,7 +1643,7 @@ ifrt_Execute(ifrt::LoadedExecutable *exec, int num_args,

// in principle, use ArrayCopySemantics::kAlwaysCopy (=0)
extern "C" FutureType *
ifrt_CopyArrayToHostBuffer(HeldValue<tsl::RCReference<xla::ifrt::Array>> *array,
ifrt_CopyArrayToHostBuffer(HeldIfrtArray *array,
void *data, ifrt::ArrayCopySemantics semantics) {
return new FutureType(
(*array)->CopyToHostBuffer(data, std::nullopt, semantics));
1 change: 1 addition & 0 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
@@ -550,6 +550,7 @@ cc_library(
# "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options",
# "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
"@xla//xla/python/ifrt/hlo:hlo_program",
"@xla//xla/python/ifrt/ir:ifrt_ir_program",
"@xla//xla/ffi:call_frame",
"@com_google_protobuf//:protobuf",

Loading