@@ -4000,28 +4000,6 @@ def body(ctx, param: ir.Value, result: ir.Value, smem: list[ir.Value]):
40004000 param = self .prng .uniform (- 1 , 1 , shape ).astype (dtype )
40014001 self .assertArraysEqual (kernel (param ), param )
40024002
def test_optimized_gmem_transfers_are_not_supported(self):
  """Expects kernel construction to fail for an optimized load of the input.

  The kernel body performs `vector_load(..., optimized=True)` on its input
  ref and casts the result to the WGMMA layout; building the kernel must
  raise `NotImplementedError` with the message asserted below.
  """

  def body(ctx, src, dst, scratch):
    # Only the input ref is used; the remaining arguments are discarded.
    del ctx, dst, scratch
    # Arguments are evaluated left to right: the load is emitted first, then
    # the layout attribute is built, then the cast is emitted.
    mgpu_dialect.layout_cast(
        vector_load(src, optimized=True),
        layouts.to_layout_attr(fa.WGMMA_LAYOUT),
    )

  # Input and output share the same shape and dtype.
  operand = jax.ShapeDtypeStruct((128, 128), jnp.bfloat16)
  expected_message = "Only optimized transfers to SMEM supported"
  with self.assertRaisesRegex(NotImplementedError, expected_message):
    mgpu.as_gpu_kernel(
        body,
        grid=(1, 1, 1),
        block=(128, 1, 1),
        in_shape=operand,
        out_shape=operand,
        smem_scratch_shape=(),
        thread_semantics=mgpu.LoweringSemantics.Warpgroup,
    )
4024-
40254003 def test_pointwise_kernel (self ):
40264004 def add (ctx , a , b , result , smem ):
40274005 del ctx , smem
@@ -4321,33 +4299,6 @@ def body(ctx, result_gmem_ref, scratch):
43214299 kernel (), jax .lax .broadcast_in_dim (x , output_shape , bcast_dims )
43224300 )
43234301
def test_bad_layout_cast_raises_in_inference(self):
  """Expects a `ValueError` when two chained layout casts cannot both hold.

  The kernel body casts a broadcast f32 vector to `WGMMA_LAYOUT` and then
  casts that result to `WGMMA_ROW_LAYOUT`; building the kernel must raise
  with the message asserted below.
  """
  shape = (128, 128)

  def body(ctx, out, _):
    del ctx, out
    elt_ty = ir.F32Type.get()
    vec_ty = ir.VectorType.get(shape, elt_ty)
    zero = arith.constant(elt_ty, 0.0)
    value = vector.broadcast(vec_ty, zero)
    # Build both layout attributes before emitting any cast, matching the
    # order in which the original test created them.
    requested_layouts = [
        layouts.to_layout_attr(layout)
        for layout in (fa.WGMMA_LAYOUT, fa.WGMMA_ROW_LAYOUT)
    ]
    # Chain the casts: each one consumes the result of the previous one.
    for layout_attr in requested_layouts:
      value = mgpu_dialect.layout_cast(value, layout_attr)

  result_aval = jax.ShapeDtypeStruct(shape, jnp.float32)
  expected_message = "user-provided layout casts are unsatisfiable"
  with self.assertRaisesRegex(ValueError, expected_message):
    mgpu.as_gpu_kernel(
        body,
        grid=(1, 1, 1),
        block=(128, 1, 1),
        in_shape=(),
        out_shape=result_aval,
        smem_scratch_shape=(),
        thread_semantics=mgpu.LoweringSemantics.Warpgroup,
    )
4350-
43514302 @parameterized .parameters (
43524303 (jnp .float32 , 5.0 , 2.0 , vector .CombiningKind .ADD ),
43534304 (jnp .float32 , 5.0 , 2.0 , vector .CombiningKind .MAXIMUMF ),
@@ -5240,32 +5191,6 @@ def matmul(ctx, a_gmem, b_gmem, result_gmem, scratch):
52405191 rtol = rtol ,
52415192 )
52425193
def test_inconsistent_collective_attributes_in_kernel_raise(self):
  """Expects a `ValueError` when `tmem_alloc` ops disagree on `collective`.

  The kernel body issues two `tmem_alloc` calls on the same ref type and
  pointer, first with `collective=False` and then with `collective=True`;
  building the kernel must raise with the message asserted below.
  """

  def body(ctx, out, smem_ptr):
    del ctx, out
    tmem_ref_ty = ir.MemRefType.get(
        (128, 128),
        ir.BF16Type.get(),
        memory_space=utils.tmem(),
    )
    # Same allocation twice, differing only in the `collective` flag.
    for collective in (False, True):
      mgpu_dialect.tmem_alloc(tmem_ref_ty, smem_ptr, collective=collective)

  scalar_i32 = jax.ShapeDtypeStruct((), jnp.int32)
  expected_message = (
      "Collective attributes are inconsistent across operations in the"
      " kernel"
  )
  with self.assertRaisesRegex(ValueError, expected_message):
    mgpu.as_gpu_kernel(
        body,
        grid=(1, 1, 1),
        block=(128, 1, 1),
        in_shape=(),
        out_shape=(scalar_i32,),
        smem_scratch_shape=scalar_i32,
        thread_semantics=mgpu.LoweringSemantics.Warpgroup,
    )
5268-
52695194 def test_slice_tmem (self ):
52705195 def tmem_type (ref : ir .Value ):
52715196 return ir .MemRefType .get (
0 commit comments