@@ -413,7 +413,9 @@ def test_infer_layout_from_body_op_to_yield_op_to_for_op(self):
413
413
shape = (64 , 64 )
414
414
with ir .InsertionPoint (self .module .body ):
415
415
c_ty = ir .VectorType .get (shape , ir .BF16Type .get ())
416
- ab_type = ir .MemRefType .get (shape , ir .BF16Type .get ())
416
+ ab_type = ir .MemRefType .get (
417
+ shape , ir .BF16Type .get (), memory_space = mgpu .utils .smem ()
418
+ )
417
419
i32 = ir .IntegerType .get_signless (32 )
418
420
lower_bound , upper_bound , step , a , b , c = undefs (
419
421
i32 , i32 , i32 , ab_type , ab_type , c_ty
@@ -794,7 +796,7 @@ def test_infer_wgmma_layout_correctly(self, lhs_memory_space):
794
796
795
797
with ir .InsertionPoint (self .module .body ):
796
798
vec_ty = ir .VectorType .get (shape , f32 )
797
- ref_ty = ir .MemRefType .get (shape , f32 )
799
+ ref_ty = ir .MemRefType .get (shape , f32 , memory_space = mgpu . utils . smem () )
798
800
lhs_ty = ref_ty if lhs_memory_space == "shared" else vec_ty
799
801
acc , lhs , rhs = undefs (vec_ty , lhs_ty , ref_ty )
800
802
wgmma_op = mgpu .dialect .WGMMAOp (acc , lhs , rhs )
@@ -1253,6 +1255,62 @@ def test_memref_load_store_op_transforms_are_empty(self):
1253
1255
self .assertEqual (inference_utils .in_transforms (load_op ), want )
1254
1256
self .assertEqual (inference_utils .in_transforms (store_op ), want )
1255
1257
1258
def test_slice_smem_gets_empty_by_default(self):
  """Checks that a lone `SliceSMEMOp` is inferred with an empty transform list.

  The module holds only an i32 value produced by the `undefs` helper and the
  slice op built from it, so nothing in this test requests any particular
  transforms. After `infer_layout` runs with `enable_smem_inference=True`,
  the op's out transforms must consist of exactly one empty `ArrayAttr`.
  """
  with ir.InsertionPoint(self.module.body):
    [offset_value] = undefs(ir.IntegerType.get_signless(32))
    smem_ref_ty = ir.MemRefType.get(
        (64, 64), ir.BF16Type.get(), memory_space=mgpu.utils.smem()
    )
    op_under_test = mgpu.dialect.SliceSMEMOp(smem_ref_ty, offset_value)

  mgpu.infer_layout(self.module, enable_smem_inference=True)

  # One entry for the op's single result, and that entry carries no transforms.
  no_transforms = ir.ArrayAttr.get([])
  inferred = inference_utils.out_transforms(op_under_test)
  self.assertSequenceEqual(inferred, [no_transforms])
1272
+
1273
def test_infer_transforms_preserves_with_transforms_requirements(self):
  """Checks that transforms requested via `with_transforms` survive inference.

  A single `with_transforms` op asks for a (8, 64) tiling followed by a
  128 swizzle on a ref. After `infer_layout` runs with
  `enable_smem_inference=True`, the out transforms recorded on `ref.owner`
  must be exactly that requested list.
  """
  with ir.InsertionPoint(self.module.body):
    smem_ref_ty = ir.MemRefType.get(
        (64, 64), ir.BF16Type.get(), memory_space=mgpu.utils.smem()
    )
    [ref] = undefs(smem_ref_ty)

    tiling = mgpu.dialect.TileTransformAttr.get((8, 64))
    swizzle = mgpu.dialect.SwizzleTransformAttr.get(128)
    requested = ir.ArrayAttr.get([tiling, swizzle])
    # The op is created only for the requirement it places on `ref`; its
    # result is not used by the test.
    mgpu.dialect.with_transforms(ref, requested)

  mgpu.infer_layout(self.module, enable_smem_inference=True)

  producer = ref.owner
  self.assertSequenceEqual(
      inference_utils.out_transforms(producer), [requested]
  )
1291
+
1292
def test_infer_transforms_fails_on_conflicting_with_transforms_requirements(self):
  """Checks that incompatible `with_transforms` requests make inference fail.

  The same ref is given two `with_transforms` ops whose transform lists
  differ only in tile shape, (8, 64) versus (16, 64), with the same 128
  swizzle. `infer_layout` with `enable_smem_inference=True` is expected to
  raise a `ValueError` whose message matches "Failed to infer".
  """

  def tiled_then_swizzled(tile_shape):
    # Builds the [tile, swizzle(128)] transform list for the given tile shape.
    return ir.ArrayAttr.get([
        mgpu.dialect.TileTransformAttr.get(tile_shape),
        mgpu.dialect.SwizzleTransformAttr.get(128),
    ])

  with ir.InsertionPoint(self.module.body):
    smem_ref_ty = ir.MemRefType.get(
        (64, 64), ir.BF16Type.get(), memory_space=mgpu.utils.smem()
    )
    [ref] = undefs(smem_ref_ty)

    first_request = tiled_then_swizzled((8, 64))
    second_request = tiled_then_swizzled((16, 64))
    mgpu.dialect.with_transforms(ref, first_request)
    mgpu.dialect.with_transforms(ref, second_request)

  with self.assertRaisesRegex(ValueError, "Failed to infer"):
    mgpu.infer_layout(self.module, enable_smem_inference=True)
1313
+
1256
1314
1257
1315
if __name__ == "__main__":
  # When executed as a script, run this module's tests through absltest with
  # the test loader supplied by `jtu`.
  loader = jtu.JaxTestLoader()
  parameterized.absltest.main(testLoader=loader)
0 commit comments