@@ -421,6 +421,10 @@ extern "C" const char *ClientGetPlatformName(PjRtClient *client) {
421
421
return cstr_from_string (client->platform_name ());
422
422
}
423
423
424
+ extern " C" const char *DeviceGetKind (PjRtDevice *device) {
425
+ return cstr_from_string (device->device_kind ());
426
+ }
427
+
424
428
// To keep in sync with JLAllocatorStats in src/XLA.jl
425
429
struct JLAllocatorStats {
426
430
int64_t num_allocs;
@@ -1258,36 +1262,6 @@ reactant_release_pjrtbuffer(HeldValue<std::shared_ptr<PjRtBuffer>> *buffer) {
1258
1262
delete buffer;
1259
1263
}
1260
1264
1261
- extern " C" ifrt::Client *
1262
- ifrt_pjrt_MakeClient (HeldValue<std::shared_ptr<PjRtClient>> *pjrt_client) {
1263
- xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj ()};
1264
- return MyValueOrThrow (xla::ifrt::PjRtClient::Create (options)).release ();
1265
- }
1266
-
1267
- extern " C" ifrt::Client *MakeCPUIfrtClient (uint8_t asynchronous, int node_id,
1268
- int num_nodes) {
1269
- return ifrt_pjrt_MakeClient (reactant_hold_pjrtclient (
1270
- MakeCPUClient (asynchronous, node_id, num_nodes)));
1271
- }
1272
-
1273
- extern " C" ifrt::Client *
1274
- MakeGPUIfrtClient (int node_id, int num_nodes, int *allowed_devices,
1275
- int num_allowed_devices, double memory_fraction,
1276
- bool preallocate, const char *platform_name,
1277
- const char **error) {
1278
- return ifrt_pjrt_MakeClient (reactant_hold_pjrtclient (
1279
- MakeGPUClient (node_id, num_nodes, allowed_devices, num_allowed_devices,
1280
- memory_fraction, preallocate, platform_name, error)));
1281
- }
1282
-
1283
- extern " C" ifrt::Client *MakeTPUIfrtClient (const char *tpu_path,
1284
- const char **error) {
1285
- return ifrt_pjrt_MakeClient (
1286
- reactant_hold_pjrtclient (MakeTPUClient (tpu_path, error)));
1287
- }
1288
-
1289
- extern " C" void ifrt_FreeClient (ifrt::Client *client) { delete client; }
1290
-
1291
1265
extern " C" xla::ifrt::LoadedExecutable *
1292
1266
ifrt_ClientCompile (ifrt::PjRtClient *client, MlirModule cmod, int64_t device_id,
1293
1267
bool is_sharded, const int64_t *mesh_ids,
@@ -1399,6 +1373,8 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
1399
1373
delete hlo_module;
1400
1374
}
1401
1375
1376
+ #pragma region IfRtClient
1377
+
1402
1378
// right now only making it available for TPU
1403
1379
// in the future, we would like this for CPU and GPU PjRt backends too
1404
1380
extern " C" ifrt::proxy::GrpcServer *
@@ -1469,6 +1445,79 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
1469
1445
.release ();
1470
1446
}
1471
1447
1448
+ extern " C" ifrt::Client *
1449
+ ifrt_pjrt_MakeClient (HeldValue<std::shared_ptr<PjRtClient>> *pjrt_client) {
1450
+ xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj ()};
1451
+ return MyValueOrThrow (xla::ifrt::PjRtClient::Create (options)).release ();
1452
+ }
1453
+
1454
+ extern " C" ifrt::Client *MakeCPUIfrtClient (uint8_t asynchronous, int node_id,
1455
+ int num_nodes) {
1456
+ return ifrt_pjrt_MakeClient (reactant_hold_pjrtclient (
1457
+ MakeCPUClient (asynchronous, node_id, num_nodes)));
1458
+ }
1459
+
1460
+ extern " C" ifrt::Client *
1461
+ MakeGPUIfrtClient (int node_id, int num_nodes, int *allowed_devices,
1462
+ int num_allowed_devices, double memory_fraction,
1463
+ bool preallocate, const char *platform_name,
1464
+ const char **error) {
1465
+ return ifrt_pjrt_MakeClient (reactant_hold_pjrtclient (
1466
+ MakeGPUClient (node_id, num_nodes, allowed_devices, num_allowed_devices,
1467
+ memory_fraction, preallocate, platform_name, error)));
1468
+ }
1469
+
1470
+ extern " C" ifrt::Client *MakeTPUIfrtClient (const char *tpu_path,
1471
+ const char **error) {
1472
+ return ifrt_pjrt_MakeClient (
1473
+ reactant_hold_pjrtclient (MakeTPUClient (tpu_path, error)));
1474
+ }
1475
+
1476
+ extern " C" void ifrt_FreeClient (ifrt::Client *client) { delete client; }
1477
+
1478
+ extern " C" int ifrt_ClientNumDevices (ifrt::Client *client) {
1479
+ return client->device_count ();
1480
+ }
1481
+
1482
+ extern " C" int ifrt_ClientNumAddressableDevices (ifrt::Client *client) {
1483
+ return client->addressable_device_count ();
1484
+ }
1485
+
1486
+ extern " C" int ifrt_ClientProcessIndex (ifrt::Client *client) {
1487
+ return client->process_index ();
1488
+ }
1489
+
1490
+ extern " C" const char *ifrt_ClientGetPlatformName (ifrt::Client *client) {
1491
+ return cstr_from_string (client->platform_name ());
1492
+ }
1493
+
1494
+ extern " C" ifrt::Device *ifrt_ClientGetDevice (ifrt::Client *client, int idx) {
1495
+ return MyValueOrThrow (client->LookupDevice (ifrt::DeviceId (idx)));
1496
+ }
1497
+
1498
+ extern " C" ifrt::Device *ifrt_ClientGetAddressableDevice (ifrt::Client *client,
1499
+ int idx) {
1500
+ return MyValueOrThrow (client->LookupAddressableDevice (idx));
1501
+ }
1502
+
1503
+ #pragma endregion
1504
+
1505
+ #pragma region IfRtDevice
1506
+
1507
+ extern " C" int64_t ifrt_DeviceGetGlobalDeviceId (ifrt::Device *device) {
1508
+ return device->Id ().value ();
1509
+ }
1510
+
1511
+ extern " C" const char *ifrt_DeviceGetKind (ifrt::Device *device) {
1512
+ return cstr_from_string (device->Kind ());
1513
+ }
1514
+
1515
+ extern " C" ifrt::Client *ifrt_DeviceToClient (ifrt::Device *device) {
1516
+ return device->client ();
1517
+ }
1518
+
1519
+ #pragma endregion
1520
+
1472
1521
#pragma region HloSharding
1473
1522
1474
1523
extern " C" void free_op_sharding (xla::OpSharding *op_sharding) {
0 commit comments