Skip to content

Commit ae11523

Browse files
giordanoPangoraw
andauthored
[deps] Update to new Enzyme-JAX (which updates to new JAX and Co.) (#961)
* [deps] Update to new Enzyme-JAX (which updates to new JAX and Co.) * Fix build --------- Co-authored-by: Paul Berg <[email protected]>
1 parent f565277 commit ae11523

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

deps/ReactantExtra/API.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -1715,8 +1715,10 @@ extern "C" bool ifrt_DeviceIsAddressable(ifrt::Device *device) {
17151715
return device->IsAddressable();
17161716
}
17171717

1718-
tsl::RCReference<ifrt::DeviceList> ifrt_CreateDeviceListFromDevices(
1719-
ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices) {
1718+
static xla::ifrt::RCReferenceWrapper<ifrt::DeviceList>
1719+
ifrt_CreateDeviceListFromDevices(ifrt::Client *client,
1720+
ifrt::Device **device_list,
1721+
int32_t num_devices) {
17201722
absl::Span<ifrt::Device *const> devices(device_list, num_devices);
17211723
return client->MakeDeviceList(devices);
17221724
}

deps/ReactantExtra/WORKSPACE

+10-10
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "7d5b22fc4d9e4d9fb28ba422cb3b6bb510ea0cf9"
12+
ENZYMEXLA_COMMIT = "6116ac04785ad6c0e160b3614fe6cadc19d479c6"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(
@@ -224,9 +224,6 @@ xla_workspace1()
224224
load("@xla//:workspace0.bzl", "xla_workspace0")
225225
xla_workspace0()
226226

227-
load("@jax//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
228-
flatbuffers()
229-
230227
load("@jax//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository")
231228
jax_python_wheel_repository(
232229
name = "jax_wheel",
@@ -235,15 +232,18 @@ jax_python_wheel_repository(
235232
)
236233

237234
load(
238-
"@tsl//third_party/py:python_wheel.bzl",
235+
"@xla//third_party/py:python_wheel.bzl",
239236
"python_wheel_version_suffix_repository",
240237
)
241238
python_wheel_version_suffix_repository(
242239
name = "jax_wheel_version_suffix",
243240
)
244241

242+
load("@jax//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
243+
flatbuffers()
244+
245245
load(
246-
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
246+
"@xla//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
247247
"cuda_json_init_repository",
248248
)
249249

@@ -255,7 +255,7 @@ load(
255255
"CUDNN_REDISTRIBUTIONS",
256256
)
257257
load(
258-
"@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
258+
"@xla//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
259259
"cuda_redist_init_repositories",
260260
"cudnn_redist_init_repository",
261261
)
@@ -269,21 +269,21 @@ cudnn_redist_init_repository(
269269
)
270270

271271
load(
272-
"@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
272+
"@xla//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
273273
"cuda_configure",
274274
)
275275

276276
cuda_configure(name = "local_config_cuda")
277277

278278
load(
279-
"@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
279+
"@xla//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
280280
"nccl_redist_init_repository",
281281
)
282282

283283
nccl_redist_init_repository()
284284

285285
load(
286-
"@tsl//third_party/nccl/hermetic:nccl_configure.bzl",
286+
"@xla//third_party/nccl/hermetic:nccl_configure.bzl",
287287
"nccl_configure",
288288
)
289289

0 commit comments

Comments
 (0)