From 875836e304ce2721511db956017857a712bb5c9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Wed, 19 Mar 2025 14:58:00 +0000 Subject: [PATCH 1/2] [deps] Update to new Enzyme-JAX (which updates to new JAX and Co.) --- deps/ReactantExtra/WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index a63b973d5..cd44ef768 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "7d5b22fc4d9e4d9fb28ba422cb3b6bb510ea0cf9" +ENZYMEXLA_COMMIT = "6116ac04785ad6c0e160b3614fe6cadc19d479c6" ENZYMEXLA_SHA256 = "" http_archive( From 7e5e710f5a8a598c57e5c352a578c85e7d9d6618 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Wed, 19 Mar 2025 17:49:34 +0100 Subject: [PATCH 2/2] Fix build --- deps/ReactantExtra/API.cpp | 6 ++++-- deps/ReactantExtra/WORKSPACE | 18 +++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 1c9183fa0..2bf2c3eea 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -1715,8 +1715,10 @@ extern "C" bool ifrt_DeviceIsAddressable(ifrt::Device *device) { return device->IsAddressable(); } -tsl::RCReference ifrt_CreateDeviceListFromDevices( - ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices) { +static xla::ifrt::RCReferenceWrapper +ifrt_CreateDeviceListFromDevices(ifrt::Client *client, + ifrt::Device **device_list, + int32_t num_devices) { absl::Span devices(device_list, num_devices); return client->MakeDeviceList(devices); } diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index cd44ef768..d5f0769ad 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -224,9 +224,6 @@ xla_workspace1() load("@xla//:workspace0.bzl", "xla_workspace0") xla_workspace0() -load("@jax//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") -flatbuffers() - load("@jax//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository") jax_python_wheel_repository( name = "jax_wheel", @@ -235,15 +232,18 @@ jax_python_wheel_repository( ) load( - "@tsl//third_party/py:python_wheel.bzl", + "@xla//third_party/py:python_wheel.bzl", "python_wheel_version_suffix_repository", ) python_wheel_version_suffix_repository( name = "jax_wheel_version_suffix", ) +load("@jax//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") +flatbuffers() + load( - "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "@xla//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", "cuda_json_init_repository", ) @@ -255,7 +255,7 @@ load( "CUDNN_REDISTRIBUTIONS", ) load( - "@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "@xla//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", "cuda_redist_init_repositories", "cudnn_redist_init_repository", ) @@ -269,21 +269,21 @@ cudnn_redist_init_repository( ) load( - "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "@xla//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure", ) cuda_configure(name = "local_config_cuda") load( - "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "@xla//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", "nccl_redist_init_repository", ) nccl_redist_init_repository() load( - "@tsl//third_party/nccl/hermetic:nccl_configure.bzl", + "@xla//third_party/nccl/hermetic:nccl_configure.bzl", "nccl_configure", )