
Commit 33c7a7e

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Improve checks on the WG strided transfers
Our previous code didn't actually verify that the vector loads/stores are safe, so you could pass in a non-contiguous reference with a weird shape but the right number of elements, and get nonsensical results. The current check is a bit too conservative, but it's better to lean this way. This also factors the address calculation into a common `transfer_strided` class method, following the same pattern we use for tiled layouts.

PiperOrigin-RevId: 814163433
1 parent 564c8c4 commit 33c7a7e
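The failure mode the commit message describes, and the conditions the new check enforces, can be sketched in plain Python (an editor's illustration, not part of the commit; check_strided_transfer is a hypothetical stand-in for the logic added to transfer_strided below):

def check_strided_transfer(shape, strides, vec_size):
  """Hypothetical stand-in for the checks transfer_strided performs."""
  if vec_size == 1:
    return  # Scalar transfers can't straddle anything.
  has_contiguous_dim = False
  for size, stride in zip(shape, strides):
    if stride == 1:
      has_contiguous_dim = True
      # Vectors are carved out of the contiguous dim, so it must divide evenly.
      if size % vec_size != 0:
        raise ValueError(f"contiguous dim {size} not a multiple of {vec_size}")
    elif size > 1 and stride % vec_size != 0:
      # Otherwise a vector could start near the end of one contiguous run
      # and spill into unrelated memory.
      raise ValueError(f"stride {stride} not a multiple of {vec_size}")
  if not has_contiguous_dim:
    raise ValueError("no contiguous dim to vectorize over")

# The old code would happily vectorize a (256, 7) view with element strides
# (42, 1): right element count, but rows of 7 can't be cut into vectors of 2.
check_strided_transfer((256, 7), (42, 1), vec_size=2)  # raises ValueError

The two divisibility conditions together guarantee that every vec_size-wide vector lies entirely within a single contiguous run of memory, which is exactly what the old code failed to verify.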

File tree

3 files changed, +102 -40 lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 55 additions & 34 deletions
@@ -966,16 +966,11 @@ def load_strided(
       )
     else:
       layout = WGStridedFragLayout(shape=shape, vec_size=vec_size)
+      registers = np.empty(layout.registers_shape(shape), dtype=object)
       vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type)
-      try:
-        # Flattening the reference potentially produces simpler PTX but
-        # if the ref is not already 1D and has strided dimensions
-        # flattening won't work.
-        ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
-        vecs = [vector.load(vec_ty, ref_, [vec_idx]) for vec_idx in layout.linear_thread_idxs()]
-      except NotImplementedError:
-        vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)]
-      return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed)
+      for _get, update, ref, idx in cls.transfer_strided(ref, layout.vec_size):
+        update(registers, vector.load(vec_ty, ref, idx))
+      return cls(_registers=registers, _layout=layout, _is_signed=is_signed)

   @classmethod
   def splat(
@@ -2579,8 +2574,10 @@ def store_untiled(
         if isinstance(ref, utils.MultimemRef):
           raise NotImplementedError("Strided layout does not support multimem")
         if swizzle != 16:
-          raise NotImplementedError
-        self._store_untiled_wg_strided(ref)
+          raise ValueError("Only TiledLayouts support swizzling")
+        assert isinstance(self.layout, WGStridedFragLayout)
+        for get, _update, ref, idx in self.transfer_strided(ref, self.layout.vec_size):
+          vector.store(get(self.registers), ref, idx)
       case TiledLayout():
         ref_shape = ir.MemRefType(ref.type).shape
         ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape))
@@ -2621,8 +2618,8 @@ def load_untiled(
       is_signed: bool | None = None,
       optimized: bool = True,
   ) -> FragmentedArray:
-    ref_shape = ir.MemRefType(ref.type).shape
-    ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape))
+    ref_ty = ir.MemRefType(ref.type)
+    ref = utils.memref_reshape(ref, (*(1 for _ in ref_ty.shape), *ref_ty.shape))
     return cls.load_tiled(
         ref, swizzle=swizzle, is_signed=is_signed, layout=layout, optimized=optimized
     )
@@ -2653,27 +2650,6 @@ def _store_untiled_splat(self, ref: ir.Value):
     )
     fa.store_untiled(ref)

-  def _store_untiled_wg_strided(self, ref: ir.Value):
-    assert isinstance(self.layout, WGStridedFragLayout)
-    ref_ty = ir.MemRefType(ref.type)
-    idxs: Iterable[Sequence[ir.Value]]
-    try:
-      # Flattening the reference potentially produces simpler PTX but
-      # if the ref is not already 1D and has strided dimensions
-      # flattening won't work. We use a different variable for ref in
-      # case `NotImplementedError` is thrown by
-      # .linear_thread_idxs().
-      ref_ = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
-      idxs = ((i,) for i in self.layout.linear_thread_idxs())
-    except NotImplementedError:
-      ref_ = ref
-      idxs = self.layout.thread_idxs(self.shape)
-    ref_shape = tuple(ref_ty.shape)
-    if ref_shape != self.shape:
-      raise ValueError((ref_shape, self.shape))
-    for idx, reg in zip(idxs, self.registers.flat):
-      vector.store(reg, ref_, idx)
-
   def store_tiled(self, ref: ir.Value | utils.MultimemRef, swizzle: int | None, optimized: bool = True):
     if not isinstance(self.layout, TiledLayout):
       raise NotImplementedError(self.layout)
@@ -2731,6 +2707,51 @@ def load_tiled(
         update(registers, loaded_reg)
     return cls(_registers=registers, _layout=layout, _is_signed=is_signed)

+  @classmethod
+  def transfer_strided(self, ref: ir.Value, vec_size: int):
+    ref_ty = ir.MemRefType(ref.type)
+    layout = WGStridedFragLayout(shape=tuple(ref_ty.shape), vec_size=vec_size)
+    try:
+      # Flattening the reference potentially produces simpler PTX but
+      # if the ref is not already 1D and has strided dimensions
+      # flattening won't work.
+      ref = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
+    except ValueError:
+      strides, _ = ref_ty.get_strides_and_offset()
+      if vec_size > 1:
+        # TODO(apaszke): We could fold all the pairs of dims that are contiguous.
+        # This check is too strict if we don't do that.
+        has_contiguous_dim = False
+        for size, stride in zip(ref_ty.shape, strides):
+          if stride == 1:
+            has_contiguous_dim = True
+            if size % vec_size != 0:
+              raise ValueError(
+                  "The contiguous dimension of the reference must be a"
+                  f" multiple of the layout's vector size (got {size} and"
+                  f" vector size {vec_size})"
+              ) from None
+          elif size > 1:
+            if stride % vec_size != 0:
+              raise ValueError(
+                  "Non-contiguous dimension of the reference must have strides"
+                  " that are multiples of the layout's vector size (got"
+                  f" {stride} and vector size {vec_size})"
+              ) from None
+        if not has_contiguous_dim:
+          raise ValueError(
+              "The reference must have a contiguous dimension when vec_size > 1"
+          )
+      idx_gen = layout.thread_idxs(tuple(ref_ty.shape))
+    else:
+      idx_gen = map(lambda x: [x], layout.linear_thread_idxs())
+    for i, vec_idx in enumerate(idx_gen):
+      def update(registers, reg, _i=i):
+        registers[_i] = reg
+      def get(registers, _i=i):
+        return registers[_i]
+      yield get, update, ref, vec_idx
+
   @staticmethod
   def transfer_tiled(
       ref: ir.Value,
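transfer_strided above follows the same accessor-generator pattern as the existing transfer_tiled: the generator owns the address calculation and yields get/update closures, so loads and stores share one traversal. A self-contained toy version of the pattern (plain Python over NumPy object arrays; the names transfer, load, and store are illustrative, not the Mosaic GPU API):

import numpy as np

def transfer(n):
  # Yields (get, update, addr) triples; the caller picks the direction.
  for i in range(n):
    def update(registers, reg, _i=i):  # used when loading from memory
      registers[_i] = reg
    def get(registers, _i=i):          # used when storing to memory
      return registers[_i]
    yield get, update, i  # `i` stands in for the computed vector index

def load(src):
  registers = np.empty(len(src), dtype=object)
  for _get, update, addr in transfer(len(src)):
    update(registers, src[addr])
  return registers

def store(registers, dst):
  for get, _update, addr in transfer(len(registers)):
    dst[addr] = get(registers)

Binding _i=i as a default argument snapshots the loop variable at each iteration, which is why the real code does the same inside its loop.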

jax/experimental/mosaic/gpu/utils.py

Lines changed: 1 addition & 1 deletion
@@ -694,7 +694,7 @@ def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value:
     new_strides[dim : dim + fold_rank] = [new_strides[dim + fold_rank - 1]]
     new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
   else:
-    raise NotImplementedError(
+    raise ValueError(
         f"strides={ref_ty.get_strides_and_offset()[0]}, {ref_ty.shape=},"
         f" {dim=}, {fold_rank=}"
     )
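The NotImplementedError-to-ValueError change is what lets transfer_strided catch the fold failure above. The situation has a direct NumPy analogy (an editor's illustration, not Mosaic GPU code): flattening a non-contiguous view is not expressible as a stride manipulation, so NumPy silently falls back to a copy, while a memref view cannot copy and has to raise:

import numpy as np

x = np.arange(32 * 2 * 32, dtype=np.int32).reshape(32, 2, 32)

contig = x.reshape(-1)             # contiguous: folding is a pure view
assert np.shares_memory(contig, x)

sliced = x[:, 1]                   # non-contiguous view, element strides (64, 1)
flat = sliced.reshape(-1)          # NumPy quietly falls back to a copy
assert not np.shares_memory(flat, x)
# memref_fold has no copying fallback, so it raises ValueError, and
# transfer_strided catches it and switches to per-dimension indexing.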

tests/mosaic/gpu_test.py

Lines changed: 46 additions & 5 deletions
@@ -413,7 +413,7 @@ def kernel(ctx, inp, out, _):
       # ("overap", (2, 4, 4), (16, 1, 1), 0, 3, True),
   ])
   def test_fold_strided(
-      self, shape, strides, dim, fold_rank, throws_not_impl
+      self, shape, strides, dim, fold_rank, throws
   ):
     expanded_shape = get_packed_shape(strides, shape)
     total_size = np.prod(expanded_shape)
@@ -426,7 +426,7 @@ def np_fold(inp, dim, fold_rank):
       out_shape[dim : dim + fold_rank] = [
           int(np.prod(inp.shape[dim : dim + fold_rank]))
       ]
-      if throws_not_impl:
+      if throws:
        return jax.ShapeDtypeStruct(shape=out_shape, dtype=inp.dtype)
      else:
        return inp.reshape(*out_shape)
@@ -442,12 +442,12 @@ def kernel(ctx, inp, out, _):
           kernel, (1, 1, 1), (128, 1, 1), np_inp, out, ()
       )(np_inp)
       assert (
-          not throws_not_impl
+          not throws
       ), "If it should have thrown it would during the call."
       np.testing.assert_array_equal(y, out)

-    if throws_not_impl:
-      with self.assertRaises(NotImplementedError):
+    if throws:
+      with self.assertRaises(ValueError):
         do_test()
     else:
       do_test()
@@ -2937,6 +2937,47 @@ def kernel(ctx, dst, _):
     rtol = 4e-6 if approx else 2e-7
     np.testing.assert_allclose(result, np_op(x), atol=atol, rtol=rtol)

+  def test_strided_copy_noncontig_good(self):
+    def kernel(ctx, src, dst, _):
+      src_slice = mgpu.memref_slice(src, (slice(None), 1))
+      mgpu.FragmentedArray.load_strided(src_slice, is_signed=True, vec_size=4).store_untiled(dst)
+
+    in_shape = jax.ShapeDtypeStruct((32, 2, 32), jnp.int32)
+    out_shape = jax.ShapeDtypeStruct((32, 32), jnp.int32)
+
+    kernel_fn = mgpu.as_gpu_kernel(
+        kernel, (1, 1, 1), (128, 1, 1), in_shape, out_shape, ()
+    )
+    x = np.arange(math.prod(in_shape.shape), dtype=jnp.int32).reshape(in_shape.shape)
+    np.testing.assert_array_equal(kernel_fn(x), x[:, 1])
+
+  def test_strided_copy_noncontig_bad(self):
+    def kernel(ctx, src, dst, _):
+      src_slice = mgpu.memref_slice(src, (slice(None), 1))
+      mgpu.FragmentedArray.load_strided(src_slice, is_signed=True, vec_size=2).store_untiled(dst)
+
+    out_shape = jax.ShapeDtypeStruct((256, 7), jnp.int32)
+
+    in_shape = jax.ShapeDtypeStruct((256, 6, 7), jnp.int32)
+    msg = (
+        "The contiguous dimension of the reference must be a multiple of the"
+        " layout's vector size (got 7 and vector size 2)"
+    )
+    with self.assertRaises(ValueError, msg=msg):
+      mgpu.as_gpu_kernel(
+          kernel, (1, 1, 1), (128, 1, 1), in_shape, out_shape, ()
+      )
+
+    in_shape = jax.ShapeDtypeStruct((256, 5, 7), jnp.int32)
+    msg = (
+        "Non-contiguous dimension of the reference must have strides that are"
+        " multiples of the layout's vector size (got 35 and vector size 2)"
+    )
+    with self.assertRaises(ValueError, msg=msg):
+      mgpu.as_gpu_kernel(
+          kernel, (1, 1, 1), (128, 1, 1), in_shape, out_shape, ()
+      )
+
   @parameterized.product(
       dtype=[jnp.float32, jnp.int32],
       m=[128],
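The shapes in these tests line up with the new checks as follows (an editor's worked example in NumPy, not part of the commit):

import numpy as np

def elem_strides(a):
  return tuple(s // a.itemsize for s in a.strides)

good = np.zeros((32, 2, 32), np.int32)[:, 1]   # shape (32, 32)
assert elem_strides(good) == (64, 1)
# vec_size=4: stride 64 % 4 == 0 and contiguous dim 32 % 4 == 0, so it loads.

bad1 = np.zeros((256, 6, 7), np.int32)[:, 1]   # element strides (42, 1)
# vec_size=2: 42 % 2 == 0 passes, but the contiguous dim fails 7 % 2 != 0,
# producing the "got 7 and vector size 2" error.

bad2 = np.zeros((256, 5, 7), np.int32)[:, 1]   # element strides (35, 1)
# vec_size=2: the non-contiguous stride is checked first and 35 % 2 != 0,
# producing the "got 35 and vector size 2" error.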
