@@ -6,8 +6,9 @@ map_ptx_to_jl_frag = Dict(
66 " u32" => UInt32 (42 ),
77 " s32" => Int32 (42 ),
88 " f16" => ntuple (i -> VecElement {Float16} (42 ), 2 ),
9- " f32" => Float32 (42 )
10- )
9+ " f32" => Float32 (42 ),
10+ " tf32" => Float32 (42 )
11+ )
1112# Return specific matrix shape given operation configuration
1213function get_array_shape (mat, mnk, layout)
1314 if ! (mat in [" a" ," b" ," c" ," d" ])
4647 # Type-dependent variables
4748 array_ty = CUDA. WMMA. map_ptx_to_jl_array[elem_type]
4849 expected = map_ptx_to_jl_frag[elem_type]
49-
50+
5051 # Address-space dependent variables
5152 do_shared_test = (addr_space == " _shared" )
5253
5354 # Get the function name
5455 func = Symbol (" llvm_wmma_load_$(mat) _$(layout) _$(shape)$(addr_space) _stride_$(elem_type) " )
55-
56+
5657 input_shape = get_array_shape (mat, mnk, layout)
5758 input = array_ty (42 ) * ones (array_ty, input_shape)
5859 input_dev = CuArray (input)
9697 elem_type in ops[3 ],
9798 addr_space in [" " , " _global" , " _shared" ],
9899 stride in [" stride" ]
99-
100+
100101 # Skip all but d matrices
101102 if mat != " d"
102103 continue
169170 ldc_func = getfield (Main, Symbol (" llvm_wmma_load_c_col_$(shape) _global_stride_$(c_elem_type) " ))
170171 # Account for half and int/subint mma different naming conventions
171172 # Int/subint mma functions are distinguished by the a/b element type
172- mma_sym = d_ty == Int32 ? Symbol (" llvm_wmma_mma_$(a_layout) _$(b_layout) _$(shape) _$(ab_elem_type) " ) :
173+ mma_sym = ( d_ty == Int32 || ab_elem_type == " tf32 " ) ? Symbol (" llvm_wmma_mma_$(a_layout) _$(b_layout) _$(shape) _$(ab_elem_type) " ) :
173174 Symbol (" llvm_wmma_mma_$(a_layout) _$(b_layout) _$(shape) _$(d_elem_type) _$(c_elem_type) " )
174- mma_func = getfield (Main, mma_sym)
175+ mma_func = getfield (Main, mma_sym)
175176 std_func = getfield (Main, Symbol (" llvm_wmma_store_d_col_$(shape) _global_stride_$(d_elem_type) " ))
176177
177178 a_shape = get_array_shape (" a" , mnk, a_layout)
205206 new_a = (a_layout == " col" ? a : transpose (a))
206207 new_b = (b_layout == " col" ? b : transpose (b))
207208 # Alter test depending on a/b element Type
208- if ab_ty == Float16
209+ if ab_ty == Float16 || ab_elem_type == " tf32 "
209210 @test new_a * new_b + c ≈ Array (d_dev) rtol= Base. rtoldefault (Float16)
210- else # Cast a and b to prevent UInt8 rollover of resultant data
211+ else # Cast a and b to prevent UInt8 rollover of resultant data
211212 @test Int32 .(new_a) * Int32 .(new_b) + c == Array (d_dev)
212213 end
213214 end
344345 @test ! occursin (r" wmma.store.d.sync(.aligned)?.col.m16n16k16.f32" , ptx)
345346 @test occursin (r" wmma.store.d.sync(.aligned)?.col.m16n16k16.shared.f32" , ptx)
346347 end
347- end
348+ end
0 commit comments