Skip to content

Commit dc25c56

Browse files
[Mosaic GPU][NFC] Specify the memory space for two existing tests.
This will become relevant later once we enable SMEM inference by default. PiperOrigin-RevId: 813743079
1 parent 8f490a9 commit dc25c56

File tree

2 files changed

+108
-2
lines changed

2 files changed

+108
-2
lines changed

jax/experimental/mosaic/gpu/layout_inference.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,17 @@ def _async_store_tmem_equation_system(
11251125
)
11261126

11271127

1128+
@_add_equation_system_derivation_rule(mgpu.SliceSMEMOp)
1129+
def _slice_smem_equation_system(
1130+
ctx: DerivationContext,
1131+
op: mgpu.SliceSMEMOp,
1132+
) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]:
1133+
del ctx
1134+
res = OperandOrResult(op, VariableType.RESULT, 0)
1135+
res_var = eqns.Variable(res)
1136+
return (eqns.EquationSystem(), {res_var: [res]}, [])
1137+
1138+
11281139
# `memref.load` and `memref.store` are used to load barrier phases which are
11291140
# scalars---the rule needn't do anything interesting, but we need to have it.
11301141
@_add_equation_system_derivation_rule(memref.LoadOp)
@@ -1148,6 +1159,43 @@ def _memref_load_store_op_equation_system(
11481159
return eqns.EquationSystem(assignments=assignments), {var: [ref]}, []
11491160

11501161

1162+
@_add_equation_system_derivation_rule(mgpu.WithTransformsOp)
1163+
def _with_transforms_equation_system(
1164+
ctx: DerivationContext,
1165+
op: mgpu.WithTransformsOp,
1166+
) -> tuple[eqns.EquationSystem, OperandOrResultsForVariable, list[Hint]]:
1167+
source = OperandOrResult(op, VariableType.OPERAND, 0)
1168+
dest = OperandOrResult(op, VariableType.RESULT, 0)
1169+
var = ctx.producer_ref(source)
1170+
1171+
transforms = [layouts_lib.from_transform_attr(x) for x in op.transforms]
1172+
match transforms:
1173+
case []:
1174+
tile_transform = None
1175+
swizzle = None
1176+
case [lc.TileTransform() as t]:
1177+
tile_transform = t
1178+
swizzle = None
1179+
case [lc.TileTransform() as t, mgpu.SwizzlingMode() as s]:
1180+
tile_transform = t
1181+
swizzle = s
1182+
case _:
1183+
raise NotImplementedError(f"Unsupported transforms {transforms}")
1184+
1185+
if swizzle is not None:
1186+
computed_swizzle = _compute_swizzle(op.ref.type, tile_transform)
1187+
if computed_swizzle != swizzle:
1188+
raise NotImplementedError(
1189+
f"Cannot honor caller-provided swizzle {swizzle} that is different "
1190+
f"from the computed swizle {computed_swizzle} on op {op}."
1191+
)
1192+
1193+
assignments: dict[eqns.Variable, eqns.Constant] = {
1194+
var: eqns.SMEMTiling(tile_transform)
1195+
}
1196+
return eqns.EquationSystem(assignments=assignments), {var: [source, dest]}, []
1197+
1198+
11511199
def _ensure_all_layouts_are_set(
11521200
op: ir.OpView, enable_smem_inference: bool
11531201
) -> None:

tests/mosaic/gpu_layout_inference_test.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,9 @@ def test_infer_layout_from_body_op_to_yield_op_to_for_op(self):
413413
shape = (64, 64)
414414
with ir.InsertionPoint(self.module.body):
415415
c_ty = ir.VectorType.get(shape, ir.BF16Type.get())
416-
ab_type = ir.MemRefType.get(shape, ir.BF16Type.get())
416+
ab_type = ir.MemRefType.get(
417+
shape, ir.BF16Type.get(), memory_space=mgpu.utils.smem()
418+
)
417419
i32 = ir.IntegerType.get_signless(32)
418420
lower_bound, upper_bound, step, a, b, c = undefs(
419421
i32, i32, i32, ab_type, ab_type, c_ty
@@ -794,7 +796,7 @@ def test_infer_wgmma_layout_correctly(self, lhs_memory_space):
794796

795797
with ir.InsertionPoint(self.module.body):
796798
vec_ty = ir.VectorType.get(shape, f32)
797-
ref_ty = ir.MemRefType.get(shape, f32)
799+
ref_ty = ir.MemRefType.get(shape, f32, memory_space=mgpu.utils.smem())
798800
lhs_ty = ref_ty if lhs_memory_space == "shared" else vec_ty
799801
acc, lhs, rhs = undefs(vec_ty, lhs_ty, ref_ty)
800802
wgmma_op = mgpu.dialect.WGMMAOp(acc, lhs, rhs)
@@ -1253,6 +1255,62 @@ def test_memref_load_store_op_transforms_are_empty(self):
12531255
self.assertEqual(inference_utils.in_transforms(load_op), want)
12541256
self.assertEqual(inference_utils.in_transforms(store_op), want)
12551257

1258+
def test_slice_smem_gets_empty_by_default(self):
1259+
with ir.InsertionPoint(self.module.body):
1260+
shape = (64, 64)
1261+
elt_ty = ir.BF16Type.get()
1262+
i32 = ir.IntegerType.get_signless(32)
1263+
[offset] = undefs(i32)
1264+
ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem())
1265+
slice_smem_op = mgpu.dialect.SliceSMEMOp(ref_ty, offset)
1266+
1267+
transforms = ir.ArrayAttr.get([])
1268+
mgpu.infer_layout(self.module, enable_smem_inference=True)
1269+
self.assertSequenceEqual(
1270+
inference_utils.out_transforms(slice_smem_op), [transforms]
1271+
)
1272+
1273+
def test_infer_transforms_preserves_with_transforms_requirements(self):
1274+
shape = (64, 64)
1275+
elt_ty = ir.BF16Type.get()
1276+
1277+
with ir.InsertionPoint(self.module.body):
1278+
ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem())
1279+
[ref] = undefs(ref_ty)
1280+
1281+
transforms = ir.ArrayAttr.get([
1282+
mgpu.dialect.TileTransformAttr.get((8, 64)),
1283+
mgpu.dialect.SwizzleTransformAttr.get(128),
1284+
])
1285+
mgpu.dialect.with_transforms(ref, transforms)
1286+
1287+
mgpu.infer_layout(self.module, enable_smem_inference=True)
1288+
self.assertSequenceEqual(
1289+
inference_utils.out_transforms(ref.owner), [transforms]
1290+
)
1291+
1292+
def test_infer_transforms_fails_on_conflicting_with_transforms_requirements(self):
1293+
shape = (64, 64)
1294+
elt_ty = ir.BF16Type.get()
1295+
1296+
with ir.InsertionPoint(self.module.body):
1297+
ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem())
1298+
[ref] = undefs(ref_ty)
1299+
1300+
transforms1 = ir.ArrayAttr.get([
1301+
mgpu.dialect.TileTransformAttr.get((8, 64)),
1302+
mgpu.dialect.SwizzleTransformAttr.get(128),
1303+
])
1304+
transforms2 = ir.ArrayAttr.get([
1305+
mgpu.dialect.TileTransformAttr.get((16, 64)),
1306+
mgpu.dialect.SwizzleTransformAttr.get(128),
1307+
])
1308+
mgpu.dialect.with_transforms(ref, transforms1)
1309+
mgpu.dialect.with_transforms(ref, transforms2)
1310+
1311+
with self.assertRaisesRegex(ValueError, "Failed to infer"):
1312+
mgpu.infer_layout(self.module, enable_smem_inference=True)
1313+
12561314

12571315
if __name__ == "__main__":
12581316
parameterized.absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)