|
73 | 73 | from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape
|
74 | 74 | from jax._src.util import (cache, canonicalize_axis,
|
75 | 75 | safe_map, safe_zip, split_list, weakref_lru_cache,
|
76 |
| - foreach) |
| 76 | + foreach, tuple_insert) |
77 | 77 |
|
78 | 78 | _max = builtins.max
|
79 | 79 | _min = builtins.min
|
@@ -9071,6 +9071,15 @@ def _empty_lower(ctx, *, shape, dtype, out_sharding):
|
9071 | 9071 | return [mlir.lower_with_sharding_in_types(ctx, out, phys_aval)]
|
9072 | 9072 | mlir.register_lowering(empty_p, _empty_lower)
|
9073 | 9073 |
|
| 9074 | +def _empty_batcher(axis_data, vals_in, dims_in, *, shape, dtype, out_sharding): |
| 9075 | + batched_shape = tuple_insert(shape, 0, axis_data.size) |
| 9076 | + batched_out_sharding = ( |
| 9077 | + None if out_sharding is None else |
| 9078 | + batching.get_sharding_for_vmap(axis_data, out_sharding, 0)) |
| 9079 | + y = empty_p.bind(shape=batched_shape, dtype=dtype, |
| 9080 | + out_sharding=batched_out_sharding) |
| 9081 | + return y, 0 |
| 9082 | +batching.fancy_primitive_batchers[empty_p] = _empty_batcher |
9074 | 9083 |
|
9075 | 9084 | # TODO(yashkatariya): Delete `empty2` and replace scan's usage with `empty` once
|
9076 | 9085 | # AllocateBuffer issues are fixed
|
|
0 commit comments