
Commit 1df8d4c

allanrenucci authored and Google-ML-Automation committed
[Mosaic GPU][NFC] Move tests that don't require a GPU to platform agnostic tests.
PiperOrigin-RevId: 827538093
1 parent e7545ec commit 1df8d4c

File tree: 3 files changed (+66, -75 lines)
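
Why these tests can move: each one asserts on an error raised while the kernel is traced and lowered to MLIR on the host, so no GPU code ever executes. A minimal sketch of that pattern, assuming a CPU-only host where `jax.experimental.mosaic.gpu` imports cleanly; the failing `body` below is a stand-in for illustration, not code from this commit:

import jax
import jax.numpy as jnp
import jax.experimental.mosaic.gpu as mgpu

def body(ctx, inp, out, scratch):
  # Stand-in for a kernel whose lowering hits an unsupported path.
  del ctx, inp, out, scratch
  raise NotImplementedError("stand-in for a lowering-time failure")

try:
  # as_gpu_kernel invokes `body` to build the MLIR module, so the error
  # surfaces here, identically with or without a GPU attached.
  mgpu.as_gpu_kernel(
      body,
      grid=(1, 1, 1),
      block=(128, 1, 1),
      in_shape=jax.ShapeDtypeStruct((128, 128), jnp.bfloat16),
      out_shape=jax.ShapeDtypeStruct((128, 128), jnp.bfloat16),
      smem_scratch_shape=(),
  )
except NotImplementedError as e:
  print(f"caught during lowering: {e}")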

tests/mosaic/gpu_dialect_test.py

Lines changed: 53 additions & 0 deletions
@@ -1435,6 +1435,59 @@ def test_memref_transforms_with_transpose(self):
     strides, _ = ty_transformed.get_strides_and_offset()
     self.assertEqual(strides, [512, 4096, 1, 16])
 
+  def test_optimized_gmem_transfers_are_not_supported(self):
+    def body(ctx, input, output, scratch):
+      del ctx, output, scratch
+      ref_type = ir.MemRefType(input.type)
+      zero = arith.constant(ir.IndexType.get(), 0)
+      zero_indices = [zero] * len(ref_type.shape)
+      vector_type = ir.VectorType.get(ref_type.shape, ref_type.element_type)
+      load = vector.LoadOp(vector_type, input, zero_indices)
+      load.attributes["optimized"] = ir.BoolAttr.get(True)
+      layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT)
+      mgpu.dialect.layout_cast(load.result, layout)
+
+    shape = (128, 128)
+    dtype = jnp.bfloat16
+    with self.assertRaisesRegex(
+        NotImplementedError, "Only optimized transfers to SMEM supported"
+    ):
+      mgpu.as_gpu_kernel(
+          body,
+          grid=(1, 1, 1),
+          block=(128, 1, 1),
+          in_shape=jax.ShapeDtypeStruct(shape, dtype),
+          out_shape=jax.ShapeDtypeStruct(shape, dtype),
+          smem_scratch_shape=(),
+          thread_semantics=mgpu.LoweringSemantics.Warpgroup,
+      )
+
+  def test_inconsistent_collective_attributes_in_kernel_raise(self):
+    def body(ctx, out, smem_ptr):
+      del ctx, out
+      ref_ty = ir.MemRefType.get(
+          (128, 128),
+          ir.BF16Type.get(),
+          memory_space=mgpu_utils.tmem(),
+      )
+      mgpu.dialect.tmem_alloc(ref_ty, smem_ptr, collective=False)
+      mgpu.dialect.tmem_alloc(ref_ty, smem_ptr, collective=True)
+
+    with self.assertRaisesRegex(
+        ValueError,
+        "Collective attributes are inconsistent across operations in the"
+        " kernel",
+    ):
+      mgpu.as_gpu_kernel(
+          body,
+          grid=(1, 1, 1),
+          block=(128, 1, 1),
+          in_shape=(),
+          out_shape=(jax.ShapeDtypeStruct((), jnp.int32),),
+          smem_scratch_shape=jax.ShapeDtypeStruct((), jnp.int32),
+          thread_semantics=mgpu.LoweringSemantics.Warpgroup,
+      )
+
 
 if __name__ == "__main__":
   parameterized.absltest.main(testLoader=jtu.JaxTestLoader())
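
Note that the moved test above builds the `vector.LoadOp` by hand and tags it with the `optimized` attribute, where the deleted `gpu_test.py` version (see below) called a `vector_load` helper. A hedged sketch of such a helper, inferred from that inline expansion; the actual helper in `gpu_test.py` may differ:

# Imports as used by the Mosaic GPU test files; these paths are internal
# to JAX and may change.
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith, vector

def vector_load(ref, *, optimized=False):
  # Load the whole memref as a vector, indexing from zero on every axis.
  ref_type = ir.MemRefType(ref.type)
  zero = arith.constant(ir.IndexType.get(), 0)
  indices = [zero] * len(ref_type.shape)
  vec_type = ir.VectorType.get(ref_type.shape, ref_type.element_type)
  load = vector.LoadOp(vec_type, ref, indices)
  if optimized:
    # Request the optimized-transfer lowering path; a GMEM source makes
    # the lowering raise, which is exactly what the test asserts.
    load.attributes["optimized"] = ir.BoolAttr.get(True)
  return load.result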

tests/mosaic/gpu_layout_inference_test.py

Lines changed: 13 additions & 0 deletions
@@ -889,6 +889,19 @@ def test_layout_cast_of_non_splat_constant_to_splat_raises(self):
     ):
       mgpu.infer_layout(self.module)
 
+  def test_layout_of_wgmma_layout_to_wgmma_row_layout_raises(self):
+    with ir.InsertionPoint(self.module.body):
+      [ref] = undefs(ir.VectorType.get((128, 128), ir.F32Type.get()))
+      wgmma_layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT)
+      wgmma_row_layout = layouts.to_layout_attr(fa.WGMMA_ROW_LAYOUT)
+      ref = mgpu.dialect.layout_cast(ref, wgmma_layout)
+      mgpu.dialect.layout_cast(ref, wgmma_row_layout)
+
+    with self.assertRaisesRegex(
+        ValueError, "user-provided layout casts are unsatisfiable"
+    ):
+      mgpu.infer_layout(self.module)
+
   def test_infer_layout_for_tmem_alloc_by_default(self):
     f32 = ir.F32Type.get()
     i32 = ir.IntegerType.get_signless(32)
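
`undefs` above is a small utility from the layout-inference test module that materializes placeholder SSA values of the requested types, so inference can run on IR that has no real inputs. A hedged sketch of such a utility, assuming it wraps `llvm.mlir.undef`; the actual definition in the test module may differ:

from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import llvm

def undefs(*types: ir.Type) -> list[ir.Value]:
  # One placeholder (undef) value per requested type; the values carry
  # only their types, which is all layout inference needs.
  return [llvm.mlir_undef(ty) for ty in types]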

tests/mosaic/gpu_test.py

Lines changed: 0 additions & 75 deletions
@@ -4000,28 +4000,6 @@ def body(ctx, param: ir.Value, result: ir.Value, smem: list[ir.Value]):
     param = self.prng.uniform(-1, 1, shape).astype(dtype)
     self.assertArraysEqual(kernel(param), param)
 
-  def test_optimized_gmem_transfers_are_not_supported(self):
-    def body(ctx, input, output, scratch):
-      del ctx, output, scratch
-      reg = vector_load(input, optimized=True)
-      layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT)
-      reg = mgpu_dialect.layout_cast(reg, layout)
-
-    shape = (128, 128)
-    dtype = jnp.bfloat16
-    with self.assertRaisesRegex(
-        NotImplementedError, "Only optimized transfers to SMEM supported"
-    ):
-      mgpu.as_gpu_kernel(
-          body,
-          grid=(1, 1, 1),
-          block=(128, 1, 1),
-          in_shape=jax.ShapeDtypeStruct(shape, dtype),
-          out_shape=jax.ShapeDtypeStruct(shape, dtype),
-          smem_scratch_shape=(),
-          thread_semantics=mgpu.LoweringSemantics.Warpgroup,
-      )
-
   def test_pointwise_kernel(self):
     def add(ctx, a, b, result, smem):
       del ctx, smem
@@ -4321,33 +4299,6 @@ def body(ctx, result_gmem_ref, scratch):
         kernel(), jax.lax.broadcast_in_dim(x, output_shape, bcast_dims)
     )
 
-  def test_bad_layout_cast_raises_in_inference(self):
-    shape = (128, 128)
-    def body(ctx, out, _):
-      del ctx, out
-      f32 = ir.F32Type.get()
-      x = vector.broadcast(
-          ir.VectorType.get(shape, f32), arith.constant(f32, 0.0)
-      )
-      wgmma_layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT)
-      wgmma_row_layout = layouts.to_layout_attr(fa.WGMMA_ROW_LAYOUT)
-      lc1 = mgpu_dialect.layout_cast(x, wgmma_layout)
-      mgpu_dialect.layout_cast(lc1, wgmma_row_layout)
-
-    dtype = jnp.float32
-    with self.assertRaisesRegex(
-        ValueError, "user-provided layout casts are unsatisfiable"
-    ):
-      mgpu.as_gpu_kernel(
-          body,
-          grid=(1, 1, 1),
-          block=(128, 1, 1),
-          in_shape=(),
-          out_shape=jax.ShapeDtypeStruct(shape, dtype),
-          smem_scratch_shape=(),
-          thread_semantics=mgpu.LoweringSemantics.Warpgroup,
-      )
-
   @parameterized.parameters(
       (jnp.float32, 5.0, 2.0, vector.CombiningKind.ADD),
       (jnp.float32, 5.0, 2.0, vector.CombiningKind.MAXIMUMF),
@@ -5240,32 +5191,6 @@ def matmul(ctx, a_gmem, b_gmem, result_gmem, scratch):
         rtol=rtol,
     )
 
-  def test_inconsistent_collective_attributes_in_kernel_raise(self):
-    def body(ctx, out, smem_ptr):
-      del ctx, out
-      ref_ty = ir.MemRefType.get(
-          (128, 128),
-          ir.BF16Type.get(),
-          memory_space=utils.tmem(),
-      )
-      mgpu_dialect.tmem_alloc(ref_ty, smem_ptr, collective=False)
-      mgpu_dialect.tmem_alloc(ref_ty, smem_ptr, collective=True)
-
-    with self.assertRaisesRegex(
-        ValueError,
-        "Collective attributes are inconsistent across operations in the"
-        " kernel",
-    ):
-      mgpu.as_gpu_kernel(
-          body,
-          grid=(1, 1, 1),
-          block=(128, 1, 1),
-          in_shape=(),
-          out_shape=(jax.ShapeDtypeStruct((), jnp.int32),),
-          smem_scratch_shape=jax.ShapeDtypeStruct((), jnp.int32),
-          thread_semantics=mgpu.LoweringSemantics.Warpgroup,
-      )
-
   def test_slice_tmem(self):
     def tmem_type(ref: ir.Value):
       return ir.MemRefType.get(
