-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathAPI.cpp
2386 lines (2018 loc) · 85 KB
/
API.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/Pass/PassManager.h"
#include "Enzyme/MLIR/Dialect/Dialect.h"
#include "Enzyme/MLIR/Dialect/Ops.h"
#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Enzyme/MLIR/Passes/Passes.h"
#include "mlir/CAPI/Support.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "llvm/Support/TargetSelect.h"
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "absl/log/globals.h"
#include "absl/log/initialize.h"
#include "xla/mlir/utils/type_util.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "tsl/platform/init_main.h"
#include "tsl/profiler/lib/profiler_session.h"
#include "tsl/profiler/lib/traceme.h"
#include "xla/python/profiler_utils.h"
#include "xla/tsl/profiler/rpc/client/capture_profile.h"
#include "xla/tsl/profiler/rpc/profiler_server.h"
#include "xla/python/ifrt/hlo/hlo_program.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Process.h"
#include "llvm/TargetParser/Host.h"
#include "llvm-c/TargetMachine.h"
// PJRT
#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/distributed/service.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_c_api_client.h"
#include "xla/pjrt/pjrt_executable.h"
// CPU collectives
#include "xla/backends/cpu/collectives/mpi_collectives.h"
#if defined(__linux__)
#include "gloo/transport/tcp/attr.h"
#include "gloo/transport/tcp/device.h"
#include "xla/backends/cpu/collectives/gloo_collectives.h"
#include "xla/backends/cpu/collectives/gloo_kv_store.h"
#elif defined(__APPLE__)
#include "gloo/transport/uv/device.h"
#include "xla/backends/cpu/collectives/gloo_collectives.h"
#include "xla/backends/cpu/collectives/gloo_kv_store.h"
#endif // defined(__linux__)
// shardy
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/transforms/passes.h"
#include "shardy/integrations/c/attributes.h"
#include "xla/pjrt/mlir_to_hlo.h"
#include "xla/service/spmd/shardy/stablehlo_round_trip/export_shardings.h"
#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_export.h"
#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h"
// IFRT
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/attribute_map.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/compiler.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/executable.h"
#include "xla/python/ifrt/hlo/hlo_program.h"
#include "xla/python/ifrt/host_callback.h"
#include "xla/python/ifrt/index.h"
#include "xla/python/ifrt/index_domain.h"
#include "xla/python/ifrt/ir/ifrt_ir_program.h"
#include "xla/python/ifrt/memory.h"
#include "xla/python/ifrt/shape.h"
#include "xla/python/ifrt/sharding.h"
#include "xla/python/ifrt/topology.h"
#include "xla/python/ifrt/tuple.h"
#include "xla/python/ifrt/value.h"
// IFRT - PJRT
#include "xla/python/pjrt_ifrt/pjrt_array.h"
#include "xla/python/pjrt_ifrt/pjrt_client.h"
#include "xla/python/pjrt_ifrt/pjrt_compiler.h"
#include "xla/python/pjrt_ifrt/pjrt_device.h"
#include "xla/python/pjrt_ifrt/pjrt_dtype.h"
#include "xla/python/pjrt_ifrt/pjrt_executable.h"
#include "xla/python/pjrt_ifrt/pjrt_host_callback.h"
#include "xla/python/pjrt_ifrt/pjrt_memory.h"
#include "xla/python/pjrt_ifrt/pjrt_topology.h"
#include "xla/python/pjrt_ifrt/pjrt_tuple.h"
#include "xla/python/pjrt_ifrt/xla_compiler.h"
#include "xla/python/pjrt_ifrt/xla_sharding.h"
// IFRT - Proxy (RPC)
#include "xla/python/ifrt_proxy/client/registry.h"
#include "xla/python/ifrt_proxy/server/grpc_server.h"
// Cost Analysis
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo_cost_analysis.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/Support/ExtensibleRTTI.h"
using namespace mlir;
using namespace llvm;
using namespace xla;
// Registration hooks for Enzyme passes. They are only declared here; their
// definitions live elsewhere (not visible in this file chunk).
namespace mlir {
namespace enzyme {
void registerRemoveTransformPass();
void registerGenerateApplyPatternsPass();
} // namespace enzyme
} // namespace mlir
namespace reactant {
// Maps a handle type to the type it refers to: std::shared_ptr<T> and
// tsl::RCReference<T> map to T; any other type maps to itself.
template <typename T> struct unwrap_type {
  typedef T type;
};
template <typename T> struct unwrap_type<std::shared_ptr<T>> {
  typedef T type;
};
template <typename T> struct unwrap_type<tsl::RCReference<T>> {
  typedef T type;
};
template <typename T> using unwrap_type_t = typename unwrap_type<T>::type;

// Owns one copy of a handle `T` (e.g. a std::shared_ptr or tsl::RCReference).
// A heap-allocated HeldValue therefore keeps the referenced object alive and
// can be passed across the C API as an opaque pointer.
template <typename T> struct HeldValue {
public:
  // Takes the handle by value and moves it into place. Rvalue handles are
  // not copied (a copy of a ref-counted handle costs a refcount increment),
  // and lvalue handles are still accepted.
  HeldValue(T obj) : holded(std::move(obj)) {}
  ~HeldValue() = default;

  // Raw, non-owning pointer to the referenced object.
  unwrap_type_t<T> *ptr() const { return holded.get(); }
  // Copies of the held handle (for ref-counted T, each copy shares ownership).
  T obj() const { return holded; }
  T value() const { return holded; }
  unwrap_type_t<T> *operator->() const { return ptr(); }

private:
  T holded;
};

// Heap-allocates a HeldValue holding `obj`. The result is allocated with
// `new`, so it must eventually be released with `delete`.
template <typename T> HeldValue<T> *capture(T obj) {
  return new HeldValue<T>(std::move(obj));
}
} // namespace reactant
using reactant::HeldValue;
// Boxed handle types (see reactant::HeldValue) used as opaque C API handles.
using HeldPjRtClient = HeldValue<std::shared_ptr<xla::PjRtClient>>;
using HeldPjRtBuffer = HeldValue<std::shared_ptr<xla::PjRtBuffer>>;
using HeldIfrtArray = HeldValue<tsl::RCReference<xla::ifrt::Array>>;
using HeldHloModule = HeldValue<std::shared_ptr<xla::HloModule>>;
// Error-reporting callback used by functions in this file. It is exported
// with C linkage so the embedding side can assign it; it stays nullptr until
// assigned.
extern "C" void (*ReactantThrowError)(const char *) = nullptr;
// Utilities for `StatusOr`.
// Unwraps `v` and returns its value. A non-OK status is first reported
// through ReactantThrowError, but only when a handler is installed. Calling
// through a null function pointer is undefined behavior. If no handler is set
// (or the handler returns), `value()` on the non-OK StatusOr fails inside absl
// (throws or aborts, depending on the build) with the status attached.
template <typename T> T MyValueOrThrow(absl::StatusOr<T> v) {
  if (!v.ok() && ReactantThrowError) {
    ReactantThrowError(v.status().ToString().c_str());
  }
  return std::move(v).value();
}
// Reports a non-zero CUDA result code through ReactantThrowError, when a
// handler is installed. A zero result is treated as success and ignored.
extern "C" void ReactantHandleCuResult(uint32_t curesult) {
  if (curesult == 0)
    return;
  if (!ReactantThrowError)
    return;
  const std::string message = "Bad Cuda Result = " + std::to_string(curesult);
  ReactantThrowError(message.c_str());
}
// MLIR C-API extras
#pragma region MLIR Extra
// Builds a complex::NumberAttr of complex type `type` with the given real and
// imaginary parts. `ctx` is not used by this implementation.
extern "C" MlirAttribute mlirComplexAttrDoubleGet(MlirContext ctx,
                                                  MlirType type, double real,
                                                  double imag) {
  auto complexTy = cast<ComplexType>(unwrap(type));
  auto attr = complex::NumberAttr::get(complexTy, real, imag);
  return wrap(attr);
}
// Checked variant of mlirComplexAttrDoubleGet. Verification diagnostics are
// attached to `loc` (presumably a null attribute is returned on failure, per
// MLIR's getChecked convention).
extern "C" MlirAttribute mlirComplexAttrDoubleGetChecked(MlirLocation loc,
                                                         MlirType type,
                                                         double real,
                                                         double imag) {
  auto complexTy = cast<ComplexType>(unwrap(type));
  return wrap(complex::NumberAttr::getChecked(unwrap(loc), complexTy, real,
                                              imag));
}
// Parses `code` into `block`. Returns true on success and false if parsing
// fails. `location` is not used.
extern "C" bool mlirOperationInject(MlirContext ctx, MlirBlock block,
                                    MlirStringRef code, MlirLocation location,
                                    bool verify_after_parse) {
  ParserConfig parserConfig(unwrap(ctx), verify_after_parse);
  const auto parsed =
      parseSourceString(unwrap(code), unwrap(block), parserConfig);
  return succeeded(parsed);
}
// Parses `code` into `block` and returns the result as a single operation.
// Returns a null operation if parsing fails. Otherwise the result comes from
// MLIR's constructContainerOpForParserIfNecessary<Operation *>, which
// (upstream implementation — confirm on LLVM bumps) detaches and returns the
// op when the block holds exactly one, and otherwise emits an error at
// `location` and yields null.
extern "C" MlirOperation mlirOperationParse(MlirContext ctx, MlirBlock block,
                                            MlirStringRef code,
                                            MlirLocation location,
                                            bool verify_after_parse) {
  ParserConfig parserConfig(unwrap(ctx), verify_after_parse);
  Block *target = unwrap(block);
  if (failed(parseSourceString(unwrap(code), target, parserConfig)))
    return MlirOperation{nullptr};

  auto owned =
      mlir::detail::constructContainerOpForParserIfNecessary<Operation *>(
          target, parserConfig.getContext(), unwrap(location));
  return MlirOperation{owned.release()};
}
// TODO mlirComplexAttrGetnValue
// TODO extern "C" MlirTypeID mlirComplexAttrGetTypeID(void) { return
// wrap(complex::NumberAttr::getTypeID()); }
// Sets attribute `name` = `attr` on result number `pos` of the function-like
// operation `op`. `op` must implement FunctionOpInterface (llvm::cast asserts
// this in assertion-enabled builds).
extern "C" void ReactantFuncSetResultAttr(MlirOperation op, intptr_t pos,
                                          MlirStringRef name,
                                          MlirAttribute attr) {
  auto func = llvm::cast<mlir::FunctionOpInterface>(unwrap(op));
  func.setResultAttr(pos, unwrap(name), unwrap(attr));
}
// Sets attribute `name` = `attr` on argument number `pos` of the
// function-like operation `op` (which must implement FunctionOpInterface).
extern "C" void ReactantFuncSetArgAttr(MlirOperation op, intptr_t pos,
                                       MlirStringRef name, MlirAttribute attr) {
  auto func = llvm::cast<mlir::FunctionOpInterface>(unwrap(op));
  func.setArgAttr(pos, unwrap(name), unwrap(attr));
}
#pragma endregion
// auxiliar functions
#pragma region utils
// Returns a malloc-allocated, NUL-terminated copy of `text`. `text` can be any
// type with data()/size(), e.g. std::string or a string_view. Exactly size()
// bytes are copied, so the source need not be NUL-terminated. The result must
// be released with free(). Returns nullptr if the allocation fails.
template <typename T> const char *cstr_from_string(const T &text) {
  const std::size_t len = text.size();
  char *cstr = static_cast<char *>(malloc(len + 1));
  if (cstr == nullptr)
    return nullptr;
  // memcpy with a null source is undefined even for 0 bytes, and an empty
  // string_view may have data() == nullptr.
  if (len != 0)
    memcpy(cstr, text.data(), len);
  cstr[len] = '\0';
  return cstr;
}
// Unwraps `status`.
// On error: stores a malloc-allocated, NUL-terminated copy of the status
// message in `*error_msg` (release with free(); nullptr if that allocation
// failed) and returns nullptr.
// On success: sets `*error_msg` to nullptr and returns `status.value()`.
// NOTE(review): the declared return type is `T *` but the success path returns
// a `T`, which only compiles when T converts implicitly to T*. The template is
// presumably never instantiated; confirm before relying on it.
template <typename T>
T *unwrap_absl_statusor(absl::StatusOr<T> status, char **error_msg) {
  *error_msg = nullptr;
  if (!status.ok()) {
    // status.message() is a string_view and is not guaranteed to be
    // NUL-terminated. Copy exactly size() bytes and terminate explicitly.
    const auto msg = status.message();
    char *err = static_cast<char *>(malloc(msg.size() + 1));
    if (err != nullptr) {
      memcpy(err, msg.data(), msg.size());
      err[msg.size()] = '\0';
    }
    *error_msg = err;
    return nullptr;
  }
  return status.value();
}
#pragma endregion
// int google::protobuf::io::CodedInputStream::default_recursion_limit_ = 100;
// int xla::_LayoutProto_default_instance_;
// Runs TSL's InitMain with a synthetic command line of just {"julia"}, then
// registers the X86 and AArch64 LLVM backends (target, target info, MC layer,
// asm printer and asm parser).
extern "C" void InitializeLogs() {
  const char *programName = "julia";
  int fakeArgc = 1;
  char *fakeArgv[] = {(char *)programName};
  char **fakeArgvPtr = fakeArgv;
  tsl::port::InitMain(programName, &fakeArgc, &fakeArgvPtr);

  // X86 backend.
  LLVMInitializeX86Target();
  LLVMInitializeX86TargetInfo();
  LLVMInitializeX86TargetMC();
  LLVMInitializeX86AsmPrinter();
  LLVMInitializeX86AsmParser();

  // AArch64 backend.
  LLVMInitializeAArch64Target();
  LLVMInitializeAArch64TargetInfo();
  LLVMInitializeAArch64TargetMC();
  LLVMInitializeAArch64AsmPrinter();
  LLVMInitializeAArch64AsmParser();
}
extern "C" void SetLogLevel(int level) {
SetStderrThreshold((absl::LogSeverity)level);
// absl::SetGlobalVLogLevel(level);
}
// Placeholder for per-module verbose logging. This is currently a no-op: the
// absl::SetVLOGLevel call is disabled, so both arguments are ignored.
extern "C" void SetModuleLogLevel(const char *module_pattern, int level) {
  (void)module_pattern;
  (void)level;
  // absl::SetVLOGLevel(module_pattern, level);
}
// Returns LLVM's default target triple. The string is produced by
// LLVMGetDefaultTargetTriple; per the LLVM-C API, it is released with
// LLVMDisposeMessage.
extern "C" char *GetDefaultTargetTriple(void) {
  char *triple = LLVMGetDefaultTargetTriple();
  return triple;
}
// Wraps integer `val` as an enzyme ActivityAttr in context `ctx`. `val` is cast
// directly to mlir::enzyme::Activity; no range check is performed.
extern "C" MLIR_CAPI_EXPORTED MlirAttribute
enzymeActivityAttrGet(MlirContext ctx, int32_t val) {
  const auto activity = static_cast<mlir::enzyme::Activity>(val);
  return wrap(mlir::enzyme::ActivityAttr::get(unwrap(ctx), activity));
}
// Create profiler session and start profiling.
// Uses the TSL default profile options with the given device and host tracer
// levels. The caller owns the returned session; ProfilerSessionDelete
// destroys it.
extern "C" tsl::ProfilerSession *
CreateProfilerSession(uint32_t device_tracer_level,
                      uint32_t host_tracer_level) {
  tensorflow::ProfileOptions opts = tsl::ProfilerSession::DefaultOptions();
  opts.set_device_tracer_level(device_tracer_level);
  opts.set_host_tracer_level(host_tracer_level);
  return tsl::ProfilerSession::Create(opts).release();
}
// Collects the data recorded by `session` and exports it to directory `path`
// in TensorBoard format, also writing a trace JSON.
// Errors from collection or export are reported through ReactantThrowError
// when a handler is installed; otherwise they are printed to llvm::errs().
// Nothing is exported if collection fails.
extern "C" void ProfilerSessionCollectData(tsl::ProfilerSession *session,
                                           const char *path) {
  auto report = [](const std::string &msg) {
    if (ReactantThrowError)
      ReactantThrowError(msg.c_str());
    else
      llvm::errs() << msg << "\n";
  };

  tensorflow::profiler::XSpace xspace;
  auto status = session->CollectData(&xspace);
  if (!status.ok()) {
    report("cannot collect data for profiler: " + status.ToString());
    return;
  }
  auto export_status =
      tsl::profiler::ExportToTensorBoard(xspace, path,
                                         /*also_export_trace_json*/ true);
  if (!export_status.ok())
    report("cannot export profiler data: " + export_status.ToString());
}
// Destroys a session created by CreateProfilerSession (nullptr is allowed).
extern "C" void ProfilerSessionDelete(tsl::ProfilerSession *session) {
  std::unique_ptr<tsl::ProfilerSession> owned(session);
  owned.reset();
}
// Opens a TraceMe activity called `name` at trace level `level` and returns
// its id, which is later passed to ProfilerActivityEnd.
extern "C" int64_t ProfilerActivityStart(const char *name, int level) {
  const int64_t activity_id =
      tsl::profiler::TraceMe::ActivityStart(name, level);
  return activity_id;
}
// Closes the TraceMe activity whose id was returned by ProfilerActivityStart.
extern "C" void ProfilerActivityEnd(int64_t id) {
  const int64_t activity_id = id;
  tsl::profiler::TraceMe::ActivityEnd(activity_id);
}
// Starts a TSL profiler server listening on `port`. The caller owns the
// returned server; ProfilerServerStop deletes it.
extern "C" tsl::profiler::ProfilerServer *ProfilerServerStart(int32_t port) {
  auto owned = std::make_unique<tsl::profiler::ProfilerServer>();
  owned->StartProfilerServer(port);
  return owned.release();
}
// Destroys a server returned by ProfilerServerStart.
extern "C" void ProfilerServerStop(tsl::profiler::ProfilerServer *server) {
  std::unique_ptr<tsl::profiler::ProfilerServer> owned(server);
  owned.reset();
}
// Builds a TFRT CPU PjRt client for process `node_id`.
// `asynchronous` is a C-style flag (non-zero means true). When `collectives`
// holds a value, it is installed as the client's collectives implementation.
// Errors go through MyValueOrThrow. The caller owns the returned client.
PjRtClient *MakeCPUClientInternal(
    uint8_t asynchronous, int node_id,
    std::optional<std::shared_ptr<xla::cpu::CpuCollectives>> collectives) {
  CpuClientOptions cpu_options;
  cpu_options.asynchronous = (asynchronous != 0);
  cpu_options.process_id = node_id;
  if (collectives)
    cpu_options.collectives = *collectives;
  return MyValueOrThrow(GetTfrtCpuClient(cpu_options)).release();
}
// C entry point that builds a CPU client without custom collectives.
extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id) {
  return MakeCPUClientInternal(asynchronous, node_id, std::nullopt);
}
// xla/python/xla.cc 390
// Creates a StreamExecutor-based GPU PjRt client.
//   node_id / num_nodes: this process's index and the number of processes.
//     When num_nodes > 1, `distributed_runtime_client` must point to a
//     HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>>, whose
//     key-value store is handed to the client.
//   allowed_devices / num_allowed_devices: optional device-id filter
//     (nullptr means no filter).
//   memory_fraction / preallocate: allocator configuration.
//   platform_name: optional platform override (nullptr means default).
//   error: receives a message when nullptr is returned.
// NOTE(review): for a missing distributed client, `*error` points to a string
// literal. For client-creation failures it is a malloc-allocated copy. Callers
// therefore cannot uniformly free() it.
extern "C" PjRtClient *
MakeGPUClient(int node_id, int num_nodes, int *allowed_devices,
              int num_allowed_devices, double memory_fraction, bool preallocate,
              const char *platform_name, const char **error,
              void *distributed_runtime_client) {
  GpuClientOptions options;
  if (num_nodes > 1) {
    if (distributed_runtime_client == nullptr) {
      *error =
          "`distributed_runtime_client` must be non-null if `num_nodes` > 1";
      return nullptr;
    }
    auto typed_distributed_runtime_client = static_cast<
        HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *>(
        distributed_runtime_client);
    options.kv_store = GetDistributedKeyValueStore(
        typed_distributed_runtime_client->obj(), /*key_prefix=*/"");
  }
  options.allocator_config.preallocate = preallocate;
  options.allocator_config.memory_fraction = memory_fraction;
  options.node_id = node_id;
  options.num_nodes = num_nodes;
  options.allowed_devices =
      allowed_devices ? std::set<int>(allowed_devices,
                                      allowed_devices + num_allowed_devices)
                      : std::optional<std::set<int>>();
  options.platform_name =
      platform_name ? std::string(platform_name) : std::optional<std::string>();
  auto clientErr = GetStreamExecutorGpuClient(options);
  if (!clientErr.ok()) {
    // status().message() is a string_view that is not guaranteed to be
    // NUL-terminated. cstr_from_string copies exactly size() bytes and
    // appends the terminator, instead of reading one byte past the view.
    *error = cstr_from_string(clientErr.status().message());
    return nullptr;
  }
  return std::move(clientErr).value().release();
}
// Environment variable that, when set, overrides the TPU plugin library path
// passed to MakeTPUClient.
const char *const kEnvTpuLibraryPath = "TPU_LIBRARY_PATH";
// Loads the PJRT plugin for `device_type` from the shared library at
// `library_path` and returns its PJRT_Api table.
// On failure, returns nullptr and stores a malloc-allocated, NUL-terminated
// error message in `*error`.
extern "C" const PJRT_Api *LoadPjrtPlugin(const char *device_type,
                                          const char *library_path,
                                          const char **error) {
  absl::StatusOr<const PJRT_Api *> pluginLoad =
      pjrt::LoadPjrtPlugin(std::string(device_type), std::string(library_path));
  if (!pluginLoad.ok()) {
    // message() is a string_view and is not guaranteed NUL-terminated. Copy
    // exactly size() bytes and terminate, rather than copying size() + 1.
    *error = cstr_from_string(pluginLoad.status().message());
    return nullptr;
  }
  return pluginLoad.value();
}
extern "C" int InitializePjrtPlugin(const char *device_type,
const char **error) {
absl::Status tpu_status = pjrt::InitializePjrtPlugin(device_type);
if (!tpu_status.ok()) {
auto str = tpu_status.message();
char *err = (char *)malloc(str.size() + 1);
memcpy(err, str.data(), str.size() + 1);
*error = err;
return 1;
}
return 0;
}
// Creates a PJRT C-API client for the already-loaded plugin `device_type`.
// The caller owns the client. Errors are reported through MyValueOrThrow, as
// in the other client constructors, instead of an unchecked StatusOr::value().
extern "C" PjRtClient *GetCApiClient(const char *device_type) {
  return MyValueOrThrow(xla::GetCApiClient(device_type)).release();
}
// Creates a TPU PjRt client through the PJRT plugin mechanism.
// The plugin library path comes from $TPU_LIBRARY_PATH if set, otherwise from
// `tpu_path`. On failure, returns nullptr with `*error` set, either here or
// by LoadPjrtPlugin / InitializePjrtPlugin.
extern "C" PjRtClient *MakeTPUClient(const char *tpu_path, const char **error) {
  std::string library_path;
  auto env_path = llvm::sys::Process::GetEnv(kEnvTpuLibraryPath);
  if (env_path) {
    library_path = *env_path;
  } else if (tpu_path != nullptr) {
    library_path = tpu_path;
  } else {
    *error = "Could not find TPU path";
    return nullptr;
  }

  const PJRT_Api *api = LoadPjrtPlugin("tpu", library_path.c_str(), error);
  if (api == nullptr)
    return nullptr;
  if (InitializePjrtPlugin("tpu", error) != 0)
    return nullptr;

  RegisterProfiler(api);
  return GetCApiClient("TPU");
}
// Total number of devices visible to `client`, across all processes.
extern "C" int ClientNumDevices(PjRtClient *client) {
  const int count = client->device_count();
  return count;
}
// Number of devices this process can address directly.
extern "C" int ClientNumAddressableDevices(PjRtClient *client) {
  const int count = client->addressable_device_count();
  return count;
}
// Index of the current process as reported by `client`.
extern "C" int ClientProcessIndex(PjRtClient *client) {
  const int index = client->process_index();
  return index;
}
// Looks up a device by global id. Lookup errors go through MyValueOrThrow.
extern "C" PjRtDevice *ClientGetDevice(PjRtClient *client, int device_id) {
  const PjRtGlobalDeviceId global_id(device_id);
  return MyValueOrThrow(client->LookupDevice(global_id));
}
// Looks up an addressable device by local id. Lookup errors go through
// MyValueOrThrow.
extern "C" PjRtDevice *ClientGetAddressableDevice(PjRtClient *client,
                                                  int device_id) {
  const PjRtLocalDeviceId local_id(device_id);
  return MyValueOrThrow(client->LookupAddressableDevice(local_id));
}
// Returns a malloc-allocated copy of the client's platform name
// (release with free()).
extern "C" const char *ClientGetPlatformName(PjRtClient *client) {
  const auto name = client->platform_name();
  return cstr_from_string(name);
}
// Returns a malloc-allocated copy of the device-kind string
// (release with free()).
extern "C" const char *DeviceGetKind(PjRtDevice *device) {
  const auto kind = device->device_kind();
  return cstr_from_string(kind);
}
// Writes every device pointer of `client` into `out_devices`. The array must
// have room for client->devices().size() entries (presumably equal to
// ClientNumDevices(client)).
extern "C" void ClientGetDevices(PjRtClient *client, PjRtDevice **out_devices) {
  auto devices = client->devices();
  for (size_t idx = 0; idx < devices.size(); ++idx)
    out_devices[idx] = devices[idx];
}
// Writes every addressable device pointer of `client` into `out_devices`. The
// array must have room for client->addressable_devices().size() entries.
extern "C" void ClientGetAddressableDevices(PjRtClient *client,
                                            PjRtDevice **out_devices) {
  auto devices = client->addressable_devices();
  for (size_t idx = 0; idx < devices.size(); ++idx)
    out_devices[idx] = devices[idx];
}
// To keep in sync with JLAllocatorStats in src/XLA.jl
// Plain C-layout record filled in by PjRtDeviceGetAllocatorStats. Field order
// and types must match the Julia definition exactly; do not reorder.
// Fields marked "optional" are set to std::numeric_limits<int64_t>::min() by
// PjRtDeviceGetAllocatorStats when the underlying value is absent.
struct JLAllocatorStats {
  int64_t num_allocs;
  int64_t bytes_in_use;
  int64_t peak_bytes_in_use;
  int64_t largest_alloc_size;
  int64_t bytes_limit; // optional
  int64_t bytes_reserved;
  int64_t peak_bytes_reserved;
  int64_t bytes_reservable_limit; // optional
  int64_t largest_free_block_bytes;
  int64_t pool_bytes;      // optional
  int64_t peak_pool_bytes; // optional
};
// Fills `jlstats` from the device's allocator statistics. Errors go through
// MyValueOrThrow. Optional statistics that are absent are stored as
// std::numeric_limits<int64_t>::min().
extern "C" void PjRtDeviceGetAllocatorStats(PjRtDevice *device,
                                            JLAllocatorStats *jlstats) {
  const auto s = MyValueOrThrow(device->GetAllocatorStats());
  constexpr int64_t kMissing = std::numeric_limits<int64_t>::min();
  JLAllocatorStats &out = *jlstats;

  out.num_allocs = s.num_allocs;
  out.bytes_in_use = s.bytes_in_use;
  out.peak_bytes_in_use = s.peak_bytes_in_use;
  out.largest_alloc_size = s.largest_alloc_size;
  out.bytes_reserved = s.bytes_reserved;
  out.peak_bytes_reserved = s.peak_bytes_reserved;
  out.largest_free_block_bytes = s.largest_free_block_bytes;

  // Optional fields.
  out.bytes_limit = s.bytes_limit.value_or(kMissing);
  out.bytes_reservable_limit = s.bytes_reservable_limit.value_or(kMissing);
  out.pool_bytes = s.pool_bytes.value_or(kMissing);
  out.peak_pool_bytes = s.peak_pool_bytes.value_or(kMissing);
}
// Destroys a loaded executable handed out by this API (nullptr is allowed).
extern "C" void ExecutableFree(xla::PjRtLoadedExecutable *exec) {
  std::unique_ptr<xla::PjRtLoadedExecutable> owned(exec);
  owned.reset();
}
// Device on which `Buffer` resides (non-owning).
extern "C" PjRtDevice *BufferToDevice(PjRtBuffer *Buffer) {
  PjRtDevice *dev = Buffer->device();
  return dev;
}
// Client that owns `Buffer` (non-owning pointer).
extern "C" PjRtClient *BufferToClient(PjRtBuffer *Buffer) {
  PjRtClient *owner = Buffer->client();
  return owner;
}
// Pointer to the buffer's dimension sizes (BufferNDimensions entries). The
// storage belongs to `Buffer`; presumably valid only while it is alive.
extern "C" const int64_t *BufferShape(PjRtBuffer *Buffer) {
  auto dims = Buffer->dimensions();
  return dims.data();
}
// Rank (number of dimensions) of `Buffer`.
extern "C" int64_t BufferNDimensions(PjRtBuffer *Buffer) {
  auto dims = Buffer->dimensions();
  return dims.size();
}
// XLA element type of `Buffer`.
extern "C" xla::PrimitiveType BufferPrimitiveType(PjRtBuffer *Buffer) {
  const xla::PrimitiveType ty = Buffer->element_type();
  return ty;
}
// Destroys a buffer handed out by this API (nullptr is allowed).
extern "C" void PjRtBufferFree(PjRtBuffer *Buffer) {
  std::unique_ptr<PjRtBuffer> owned(Buffer);
  owned.reset();
}
// Client that owns `Device` (non-owning pointer).
extern "C" PjRtClient *DeviceToClient(PjRtDevice *Device) {
  PjRtClient *owner = Device->client();
  return owner;
}
// Client that compiled / owns `exec` (non-owning pointer).
extern "C" PjRtClient *
PjRtLoadedExecutableGetClient(PjRtLoadedExecutable *exec) {
  PjRtClient *owner = exec->client();
  return owner;
}
// https://openxla.org/xla/shapes
// This minor-to-major dimension order of 0 up to N-1 is akin to column-major
// (at rank 2). Assuming a monotonic ordering of dimensions, another way we may
// refer to this layout in the code is simply "dim 0 is minor".
//
// Returns the minor_to_major permutation {0, 1, ..., dim-1}. A non-positive
// `dim` yields an empty vector.
std::vector<int64_t> col_major(int64_t dim) {
  std::vector<int64_t> minor_to_major;
  minor_to_major.reserve(dim > 0 ? static_cast<size_t>(dim) : 0);
  // The counter has the same type as `dim`, so it cannot overflow before
  // reaching the bound (an `int` counter could for dim > INT_MAX).
  for (int64_t i = 0; i < dim; ++i)
    minor_to_major.push_back(i);
  return minor_to_major;
}
extern "C" void ReactantLLVMParseCommandLineOptions(int argc,
const char *const *argv,
const char *Overview) {
llvm::cl::ParseCommandLineOptions(argc, argv, StringRef(Overview),
&llvm::nulls());
}
// Returns the minor_to_major permutation {dim-1, ..., 1, 0}, i.e. a row-major
// layout in which the last dimension is minor. A non-positive `dim` yields an
// empty vector.
std::vector<int64_t> row_major(int64_t dim) {
  std::vector<int64_t> minor_to_major;
  minor_to_major.reserve(dim > 0 ? static_cast<size_t>(dim) : 0);
  // Count down with an int64_t so the counter matches `dim` and cannot
  // overflow (the previous `int` counter could for dim > INT_MAX).
  for (int64_t i = dim - 1; i >= 0; --i)
    minor_to_major.push_back(i);
  return minor_to_major;
}
// Does nothing; file-local helper.
static void noop() {
  // Intentionally empty.
}
#ifdef REACTANT_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
// CUDA driver version as reported by cuDriverGetVersion. A failing result
// code is routed through ReactantHandleCuResult.
extern "C" int32_t ReactantCudaDriverGetVersion() {
  int32_t version;
  ReactantHandleCuResult(cuDriverGetVersion(&version));
  return version;
}
// CUDA_VERSION of the CUDA headers this file was compiled against.
extern "C" int32_t ReactantHermeticCudaGetVersion() {
  return CUDA_VERSION;
}
#else
// Built without CUDA support: both version queries report 0.
extern "C" int32_t ReactantCudaDriverGetVersion() {
  return 0;
}
extern "C" int32_t ReactantHermeticCudaGetVersion() {
  return 0;
}
#endif
// Returns the raw address of `buffer`'s storage, as reported by its client's
// UnsafeBufferPointer. Errors go through MyValueOrThrow.
extern "C" void *UnsafeBufferPointer(PjRtBuffer *buffer) {
  PjRtClient *owner = buffer->client();
  const auto address = MyValueOrThrow(owner->UnsafeBufferPointer(buffer));
  return reinterpret_cast<void *>(address);
}
// Creates a buffer in `device`'s default memory space from host memory
// `data`. The host memory holds elements of primitive type `ptype` with `dim`
// dimensions whose sizes are in `cshape`.
// The host data only needs to stay valid during this call
// (kImmutableOnlyDuringCall). No explicit layout is requested (nullptr). The
// caller owns the returned buffer; errors go through MyValueOrThrow.
extern "C" PjRtBuffer *ArrayFromHostBuffer(PjRtClient *client, void *data,
                                           uint64_t ptype, size_t dim,
                                           int64_t *cshape,
                                           PjRtDevice *device) {
  const auto element_type = static_cast<xla::PrimitiveType>(ptype);
  const absl::Span<const int64_t> dims(cshape, dim);
  const auto semantics =
      PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall;
  const xla::Layout *no_layout = nullptr;
  PjRtMemorySpace *memory_space = *device->default_memory_space();

  auto owned = MyValueOrThrow(client->BufferFromHostBuffer(
      data, element_type, dims, /*byte_strides*/ {}, semantics,
      /*ondone*/ {}, memory_space, no_layout));
  return owned.release();
}
// 1 if `buffer` lives in CPU memory, 0 otherwise.
extern "C" uint8_t BufferOnCPU(PjRtBuffer *buffer) {
  const bool on_cpu = buffer->IsOnCpu();
  return on_cpu;
}
// Copies `buffer` into `dst_device`'s default memory space and returns the new
// buffer, which the caller owns. Errors go through MyValueOrThrow.
extern "C" PjRtBuffer *CopyBufferToDevice(PjRtBuffer *buffer,
                                          PjRtDevice *dst_device) {
  auto *target_space = *dst_device->default_memory_space();
  auto copied = MyValueOrThrow(buffer->CopyToMemorySpace(target_space));
  return copied.release();
}
// Copies the contents of `buffer` into host memory at `data`. `data` must be
// large enough for the buffer's host shape.
// Grumpily the cpu copy code does not respect layout and does a raw copy.
// For now, we assume a non-julia row major ordering; if in the future it
// supports col_major we can swap to that.
// A failed copy leaves `data` unfilled, so it is reported through
// ReactantThrowError when a handler is installed. Without a handler, the
// message is printed to stderr.
extern "C" void BufferToHost(PjRtBuffer *buffer, void *data) {
  Shape shape(MyValueOrThrow(buffer->HostShape()));
  *shape.mutable_layout() = xla::Layout(row_major(shape.dimensions_size()));
  MutableBorrowingLiteral literal((const char *)data, shape);
  auto status = buffer->ToLiteralSync(&literal);
  if (!status.ok()) {
    const std::string msg = "error copying to host: " + status.ToString();
    if (ReactantThrowError)
      ReactantThrowError(msg.c_str());
    else
      fprintf(stderr, "%s\n", msg.c_str());
  }
}
// Destroys a client handed out by this API (nullptr is allowed).
extern "C" void FreeClient(PjRtClient *client) {
  std::unique_ptr<PjRtClient> owned(client);
  owned.reset();
}
// Local (per-process) device id as a plain integer.
extern "C" int64_t PjRtDeviceGetLocalDeviceId(PjRtDevice *device) {
  const auto local_id = device->local_device_id();
  return local_id.value();
}
// Global device id as a plain integer.
extern "C" int64_t PjRtDeviceGetGlobalDeviceId(PjRtDevice *device) {
  const auto global_id = device->global_device_id();
  return global_id.value();
}
// Local hardware id of the device as a plain integer.
extern "C" int64_t PjRtDeviceGetLocalHardwareId(PjRtDevice *device) {
  const auto hw_id = device->local_hardware_id();
  return hw_id.value();
}
#include "xla/service/custom_call_target_registry.h"
// Registers `address` as custom-call target `name` for `platform` in XLA's
// global CustomCallTargetRegistry.
extern "C" void RegisterCustomCallTarget(const char *name, void *address,
                                         const char *platform) {
  auto *registry = CustomCallTargetRegistry::Global();
  registry->Register(std::string(name), address, std::string(platform));
}
#include "mlir/Target/LLVMIR/Import.h"
// Translates the LLVM module `lmod` into an MLIR ModuleOp in context `cctx`.
// Takes ownership of `lmod`, which is moved into the translator. If
// translation fails, the returned module is null.
extern "C" MlirModule ConvertLLVMToMLIR(LLVMModuleRef lmod, MlirContext cctx) {
  std::unique_ptr<llvm::Module> owned(unwrap(lmod));
  mlir::MLIRContext *ctx = unwrap(cctx);
  auto translated = mlir::translateLLVMIRToModule(
      std::move(owned), ctx, /*emitExpensiveWarnings*/ false,
      /*dropDICompositeElements*/ false);
  return wrap(translated.release());
}
#include "llvm/IRReader/IRReader.h"
// Parses NUL-terminated textual LLVM IR `lmod` and translates it into an MLIR
// ModuleOp in context `cctx`.
// On a parse or translation error, the input IR is echoed to llvm::errs().
// The error is then reported through ReactantThrowError when a handler is
// installed, or printed to llvm::errs() otherwise, and a null module is
// returned. A parse failure never reaches the translator with a null module.
extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
  LLVMContext Context;
  SMDiagnostic Err;
  auto llvmModule =
      llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Err, Context);
  if (!llvmModule) {
    std::string err_str;
    llvm::raw_string_ostream err_stream(err_str);
    Err.print(/*ProgName=*/"LLVMToMLIR", err_stream);
    err_stream.flush();
    llvm::errs() << lmod << "\n";
    if (ReactantThrowError)
      ReactantThrowError(err_str.c_str());
    else
      llvm::errs() << err_str;
    return wrap((mlir::ModuleOp) nullptr);
  }
  mlir::MLIRContext &context = *unwrap(cctx);
  auto res = mlir::translateLLVMIRToModule(std::move(llvmModule), &context,
                                           /*emitExpensiveWarnings*/ false,
                                           /*dropDICompositeElements*/ false)
                 .release();
  if (!res) {
    llvm::errs() << lmod << "\n";
    const char *msg = "Could not translate LLVM IR to MLIR Module";
    if (ReactantThrowError)
      ReactantThrowError(msg);
    else
      llvm::errs() << msg << "\n";
  }
  return wrap(res);
}
// Future type handed across the C API for asynchronous execution results.
using FutureType = PjRtFuture<>;
// Destroys a future handed out by this API (nullptr is allowed).
extern "C" void FreeFuture(FutureType *Future) {
  std::unique_ptr<FutureType> owned(Future);
  owned.reset();
}
// 1 if `Future` has completed, 0 otherwise. Does not block.
extern "C" uint8_t FutureIsReady(FutureType *Future) {
  const bool ready = Future->IsReady();
  return ready;
}
// Blocks until `Future` completes. A failed future used to be silently
// ignored. Its status is now reported through ReactantThrowError when a
// handler is installed.
extern "C" void FutureAwait(FutureType *Future) {
  auto status = Future->Await();
  if (!status.ok() && ReactantThrowError)
    ReactantThrowError(status.ToString().c_str());
}
// Builds XLA CompileOptions for one of two modes:
//  * sharded (`is_sharded`): one replica and `num_mesh_ids` SPMD partitions.
//    The device assignment is mesh_ids[0..num_mesh_ids), each entry
//    asserted >= 0. Shardy is used iff `use_shardy_partitioner`, and sharding
//    propagation to parameters/outputs is disabled. `device_id` is asserted
//    negative.
//  * single device: one replica, one partition, placed on `device_id`
//    (asserted >= 0).
// `xla_gpu_cuda_data_dir` is forwarded to the debug options in both modes.
xla::CompileOptions GenerateCompileOptions(int64_t device_id, bool is_sharded,
                                           const int64_t *mesh_ids,
                                           int64_t num_mesh_ids,
                                           const char *xla_gpu_cuda_data_dir,
                                           bool use_shardy_partitioner) {
  xla::CompileOptions options;
  auto &build = options.executable_build_options;
  build.mutable_debug_options()->set_xla_gpu_cuda_data_dir(
      xla_gpu_cuda_data_dir);

  if (!is_sharded) {
    assert(device_id >= 0);
    build.set_num_replicas(1);
    build.set_num_partitions(1);
    build.set_device_ordinal(device_id);
    xla::DeviceAssignment single(1, 1);
    single(0, 0) = device_id;
    build.set_device_assignment(single);
    return options;
  }

  assert(device_id < 0);
  build.set_num_replicas(1);
  build.set_num_partitions(num_mesh_ids);
  build.set_use_spmd_partitioning(true);
  build.set_use_shardy_partitioner(use_shardy_partitioner);
  // Auto SPMD partitioning (set_use_auto_spmd_partitioning plus the
  // auto_spmd_partitioning_mesh_* setters) is not available for GPUs in the
  // open-source XLA, so the mesh is given as an explicit device assignment.
  xla::DeviceAssignment mesh(1, num_mesh_ids);
  for (int64_t p = 0; p < num_mesh_ids; ++p) {
    assert(mesh_ids[p] >= 0);
    mesh(0, p) = mesh_ids[p];
  }
  build.set_device_assignment(mesh);
  build.set_allow_spmd_sharding_propagation_to_parameters({false});
  build.set_allow_spmd_sharding_propagation_to_output({false});
  return options;
}
// Compiles MLIR module `cmod` with `client`, using options from
// GenerateCompileOptions. For sharded Shardy compilations, the module is
// first exported for the HLO round trip. Errors are reported through
// ReactantThrowError; for compilation failures the module text is included.
// The caller owns the returned executable.
extern "C" xla::PjRtLoadedExecutable *
ClientCompile(PjRtClient *client, MlirModule cmod, int64_t device_id,
              bool is_sharded, const int64_t *mesh_ids, int64_t num_mesh_ids,
              const char *xla_gpu_cuda_data_dir, bool use_shardy_partitioner) {
  CompileOptions compile_options =
      GenerateCompileOptions(device_id, is_sharded, mesh_ids, num_mesh_ids,
                             xla_gpu_cuda_data_dir, use_shardy_partitioner);
  mlir::ModuleOp module_op = cast<ModuleOp>(*unwrap(cmod));

  if (is_sharded && use_shardy_partitioner) {
    // https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460
    auto export_status = xla::ExportShardyForHloRoundTrip(module_op);
    if (!export_status.ok())
      ReactantThrowError(export_status.ToString().c_str());
  }

  auto compiled = client->Compile(module_op, compile_options);
  if (!compiled.ok()) {
    std::string report;
    llvm::raw_string_ostream os(report);
    os << module_op << "\n" << compiled.status().ToString();
    ReactantThrowError(os.str().c_str());
  }
  return std::move(compiled).value().release();
}
// Stores a heap-allocated copy of each output sharding of `exec` in
// op_shardings[0..num_op_shardings). The caller owns each xla::OpSharding.
// Reports through ReactantThrowError if no output shardings are available or
// if their count differs from `num_op_shardings`.
extern "C" void
PjRtLoadedExecutableGetOuputShardings(xla::PjRtLoadedExecutable *exec,
                                      xla::OpSharding **op_shardings,
                                      int32_t num_op_shardings) {
  std::optional<std::vector<OpSharding>> maybe = exec->GetOutputShardings();
  if (!maybe.has_value()) {
    ReactantThrowError(
        "No sharding found for the output of the loaded executable");
  }
  std::vector<xla::OpSharding> found = maybe.value();
  if (num_op_shardings != found.size()) {
    const std::string msg = "Expected " + std::to_string(num_op_shardings) +
                            " shardings, got " + std::to_string(found.size());
    ReactantThrowError(msg.c_str());
  }
  for (int32_t k = 0; k < num_op_shardings; ++k)
    op_shardings[k] = new xla::OpSharding(found[k]);
}
// Stores a heap-allocated copy of each parameter sharding of `exec` in
// op_shardings[0..num_op_shardings). The caller owns each xla::OpSharding.
// Reports through ReactantThrowError, and returns without writing anything,
// if no parameter shardings are available or if their count differs from
// `num_op_shardings`.
extern "C" void
PjRtLoadedExecutableGetParameterShardings(xla::PjRtLoadedExecutable *exec,
                                          xla::OpSharding **op_shardings,
                                          int32_t num_op_shardings) {
  std::optional<std::vector<OpSharding>> shardings =
      exec->GetParameterShardings();
  if (!shardings.has_value()) {
    // This function queries parameters, so the message says so (it used to
    // say "output").
    ReactantThrowError(
        "No sharding found for the parameters of the loaded executable");
    return;
  }
  // Borrow rather than copy the whole vector.
  const std::vector<xla::OpSharding> &hlo_op_shardings = *shardings;
  if (num_op_shardings < 0 ||
      static_cast<size_t>(num_op_shardings) != hlo_op_shardings.size()) {
    ReactantThrowError(("Expected " + std::to_string(num_op_shardings) +
                        " shardings, got " +
                        std::to_string(hlo_op_shardings.size()))
                           .c_str());
    // Never index past the end if the error handler returns.
    return;
  }
  for (int32_t i = 0; i < num_op_shardings; i++) {
    op_shardings[i] = new xla::OpSharding(hlo_op_shardings[i]);
  }
}
// Runs `exec` on the single `device` with `num_args` argument buffers.
// An argument i with is_arg_donatable[i] == 0 is marked non-donatable.
// Results are untupled; exactly `num_results` buffers are expected, and
// ownership of each is transferred into `op_results`. If execution produced a
// future, `*futures` is set to 1 and each future_results[i] receives its own
// heap-allocated copy of it (caller-owned). Otherwise `*futures` is 0.
extern "C" void XLAExecuteSharded(xla::PjRtLoadedExecutable *exec, int num_args,
                                  PjRtBuffer **op_args, PjRtDevice *device,
                                  uint8_t *is_arg_donatable, int num_results,
                                  PjRtBuffer **op_results, uint8_t *futures,
                                  FutureType **future_results) {
  std::vector<PjRtBuffer *> args(op_args, op_args + num_args);

  ExecuteOptions exec_options;
  exec_options.untuple_result = true;
  for (int idx = 0; idx < num_args; ++idx) {
    if (is_arg_donatable[idx] == 0)
      exec_options.non_donatable_input_indices.insert(idx);
  }

  std::optional<PjRtFuture<>> maybe_future;
  auto outputs = MyValueOrThrow(exec->ExecuteSharded(
      args, device, exec_options, maybe_future, /*fill_future=*/true));

  if (outputs.size() != num_results) {
    ReactantThrowError(
        ("Error: results.size()=" + std::to_string(outputs.size()) +
         " does not match num_results=" + std::to_string(num_results) + "\n")
            .c_str());
  }

  const bool has_future = maybe_future.has_value();
  *futures = has_future;
  if (has_future) {
    for (int idx = 0; idx < num_results; ++idx)
      future_results[idx] = new FutureType(*maybe_future);
  }

  for (int idx = 0; idx < num_results; ++idx)
    op_results[idx] = outputs[idx].release();
}
// This isn't exposed to julia, but leaving it here since it is very useful for
// debugging sharding (and generally for the execute workflow)
// Prints every element of `buffer` to stdout on one line, or a notice if
// `buffer` is null. The elements are copied into a std::vector<float>, so only
// F32 buffers are supported. Any other element type would make BufferToHost
// write a different number of bytes than the vector holds.
void PrintPjRtBuffer(PjRtBuffer *buffer) {
  if (!buffer) {
    std::cout << " Buffer is nullptr" << std::endl;
    return;
  }
  xla::Shape shape = MyValueOrThrow(buffer->HostShape());
  if (shape.element_type() != xla::F32) {
    std::cout << " PrintPjRtBuffer only supports F32 buffers" << std::endl;
    return;
  }
  auto dims = shape.dimensions();
  // The accumulator type is the type of the initial value. An `int` 1 would
  // truncate the element count to int, so start from int64_t{1}.
  const int64_t nelems = std::accumulate(dims.begin(), dims.end(), int64_t{1},
                                         std::multiplies<int64_t>());
  std::vector<float> host_data(nelems);
  BufferToHost(buffer, host_data.data());
  for (int64_t i = 0; i < nelems; ++i) {
    std::cout << host_data[i] << " ";
  }
  std::cout << std::endl;
}
extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len,
PjRtBuffer **op_args, uint8_t *is_arg_donatable,
int num_results, PjRtBuffer **op_results,
uint8_t *futures, FutureType **future_results) {
xla::DeviceAssignment device_assignment = exec->device_assignment();
int num_devices = device_assignment.computation_count();
// Ensure argument_handles is structured as num_devices x num_args
std::vector<std::vector<PjRtBuffer *>> argument_handles(num_devices);
int num_args = op_args_len / num_devices;
// Distribute arguments across devices
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
argument_handles[device_idx].reserve(num_args);
for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) {
argument_handles[device_idx].push_back(
op_args[device_idx * num_args + arg_idx]);
}
}