
Commit 6edab6e

yashk2810 authored and Google-ML-Automation committed
Reverts a01fbc2
PiperOrigin-RevId: 814008079
1 parent: a01fbc2

2 files changed (+22 lines, −1 line)


jax/_src/lax/lax.py (10 additions, 1 deletion)

@@ -73,7 +73,7 @@
 from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape
 from jax._src.util import (cache, canonicalize_axis,
                            safe_map, safe_zip, split_list, weakref_lru_cache,
-                           foreach)
+                           foreach, tuple_insert)
 
 _max = builtins.max
 _min = builtins.min

@@ -9071,6 +9071,15 @@ def _empty_lower(ctx, *, shape, dtype, out_sharding):
   return [mlir.lower_with_sharding_in_types(ctx, out, phys_aval)]
 mlir.register_lowering(empty_p, _empty_lower)
 
+def _empty_batcher(axis_data, vals_in, dims_in, *, shape, dtype, out_sharding):
+  batched_shape = tuple_insert(shape, 0, axis_data.size)
+  batched_out_sharding = (
+      None if out_sharding is None else
+      batching.get_sharding_for_vmap(axis_data, out_sharding, 0))
+  y = empty_p.bind(shape=batched_shape, dtype=dtype,
+                   out_sharding=batched_out_sharding)
+  return y, 0
+batching.fancy_primitive_batchers[empty_p] = _empty_batcher
 
 # TODO(yashkatariya): Delete `empty2` and replace scan's usage with `empty` once
 # AllocateBuffer issues are fixed
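
The heart of the new batching rule is a shape transformation: under vmap, the per-element empty(shape) allocation is replaced by a single allocation with the batch size prepended at axis 0. A minimal sketch of that transformation, with the `tuple_insert` helper re-implemented here purely for illustration (the real one is imported from `jax._src.util` in the diff above):

# Illustrative stand-in for jax._src.util.tuple_insert: insert `val`
# into tuple `t` at position `idx`.
def tuple_insert(t, idx, val):
  return t[:idx] + (val,) + t[idx:]

# With a vmapped axis of size 4, a per-element empty((2,), ...) becomes
# one batched empty((4, 2), ...), mapped along axis 0, which is why the
# batcher returns (y, 0).
assert tuple_insert((2,), 0, 4) == (4, 2)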

tests/api_test.py (12 additions, 0 deletions)

@@ -4467,6 +4467,18 @@ def test_lax_real_empty(self):
     self.assertEqual(out.shape, (2, 2))
     self.assertEqual(out.dtype, jnp.float32)
 
+  @jtu.run_on_devices('gpu', 'tpu')
+  def test_lax_empty_vmap(self):
+    inp = np.arange(8, dtype=jnp.int32).reshape(4, 2)
+
+    def f(x):
+      return jax.lax.empty(x.shape, x.dtype)
+
+    f = jax.jit(jax.vmap(f))
+    f(inp)  # doesn't crash
+    lowered_text = f.lower(inp).as_text()
+    self.assertIn('@AllocateBuffer() : () -> tensor<4x2xi32>', lowered_text)
+
   def test_leaked_tracer_issue_7613(self):
     # from https://github.com/jax-ml/jax/issues/7613
     import numpy.random as npr
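
For reference, a standalone sketch of the behavior the new test pins down, outside the test harness. This assumes a GPU or TPU backend, where `jax.lax.empty` lowers to the AllocateBuffer custom call per the assertion above:

import jax
import numpy as np

inp = np.arange(8, dtype=np.int32).reshape(4, 2)

def f(x):
  # Under vmap, x is one row of shape (2,), so this requests a
  # per-element empty((2,), int32); the batching rule collapses the
  # whole batch into a single empty((4, 2), int32) allocation.
  return jax.lax.empty(x.shape, x.dtype)

f = jax.jit(jax.vmap(f))
out = f(inp)      # doesn't crash now that empty_p has a batching rule
print(out.shape)  # (4, 2); the buffer contents are uninitialized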
