@@ -14,6 +14,7 @@ const map_ptx_to_jl_array = Dict(
1414 " s8" => Int8,
1515 " s32" => Int32,
1616 " f16" => Float16,
17+ " tf32" => Float32,
1718 " f32" => Float32
1819 )
1920
@@ -23,6 +24,7 @@ const map_ptx_to_jl_frag = Dict(
2324 " s8" => UInt32,
2425 " s32" => Int32,
2526 " f16" => NTuple{2 , VecElement{Float16}},
27+ " tf32" => Float32,
2628 " f32" => Float32
2729 )
2830
@@ -40,6 +42,8 @@ const map_frag_sizes = Dict(
4042 " a.f16.m16n16k16" => 8 ,
4143 " a.f16.m8n32k16" => 8 ,
4244 " a.f16.m32n8k16" => 8 ,
45+
46+ " a.tf32.m16n16k8" => 8 ,
4347 # B
4448 " b.u8.m16n16k16" => 2 ,
4549 " b.u8.m8n32k16" => 4 ,
@@ -52,6 +56,8 @@ const map_frag_sizes = Dict(
5256 " b.f16.m16n16k16" => 8 ,
5357 " b.f16.m8n32k16" => 8 ,
5458 " b.f16.m32n8k16" => 8 ,
59+
60+ " b.tf32.m16n16k8" => 8 ,
5561 # C
5662 " c.s32.m16n16k16" => 8 ,
5763 " c.s32.m8n32k16" => 8 ,
@@ -64,6 +70,8 @@ const map_frag_sizes = Dict(
6470 " c.f32.m16n16k16" => 8 ,
6571 " c.f32.m8n32k16" => 8 ,
6672 " c.f32.m32n8k16" => 8 ,
73+
74+ " c.f32.m16n16k8" => 8 ,
6775 # D
6876 " d.s32.m16n16k16" => 8 ,
6977 " d.s32.m8n32k16" => 8 ,
@@ -76,6 +84,8 @@ const map_frag_sizes = Dict(
7684 " d.f32.m16n16k16" => 8 ,
7785 " d.f32.m8n32k16" => 8 ,
7886 " d.f32.m32n8k16" => 8 ,
87+
88+ " d.f32.m16n16k8" => 8 ,
7989 )
8090
8191# Maps PTX AS to CUDA.AS
@@ -87,6 +97,10 @@ const map_ptx_as_to_as_ty = Dict(
8797
8898# Valid WMMA Operation configurations: Shape (M,N,K), Matrix, Element Type
8999
100+ # TF32-Precision Floating Point
101+ const ldst_tf32_ab_ops = [(16 ,16 ,8 )], [" a" , " b" ], [" tf32" ]
102+ const ldst_tf32_cd_ops = [(16 ,16 ,8 )], [" c" , " d" ], [" f32" ]
103+ const wmma_tf32_ops = [(16 ,16 ,8 )], [" tf32" ], [" f32" ], [" f32" ]
90104# Half-Precision Floating Point
91105const ldst_half_ab_ops = [(16 ,16 ,16 ), (32 ,8 ,16 ), (8 ,32 ,16 )], [" a" , " b" ], [" f16" ]
92106const ldst_half_cd_ops = [(16 ,16 ,16 ), (32 ,8 ,16 ), (8 ,32 ,16 )], [" c" , " d" ], [" f16" , " f32" ]
@@ -97,11 +111,12 @@ const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
97111const wmma_int_ops = [(16 ,16 ,16 ), (32 ,8 ,16 ), (8 ,32 ,16 )], [" s8" , " u8" ], [" s32" ], [" s32" ]
98112
99113const all_ldst_ops = vcat (ldst_half_ab_ops, ldst_half_cd_ops,
100- ldst_int_ab_ops, ldst_int_cd_ops)
101- const all_wmma_ops = vcat (wmma_half_ops, wmma_int_ops)
114+ ldst_int_ab_ops, ldst_int_cd_ops,
115+ ldst_tf32_ab_ops, ldst_tf32_cd_ops)
116+ const all_wmma_ops = vcat (wmma_half_ops, wmma_int_ops, wmma_tf32_ops)
102117
103118# Valid WMMA operation shapes
104- const valid_shapes = [(16 , 16 , 16 ), (32 , 8 , 16 ), (8 , 32 , 16 )]
119+ const valid_shapes = [(16 , 16 , 16 ), (32 , 8 , 16 ), (8 , 32 , 16 ), ( 16 , 16 , 8 ) ]
105120
106121# ###############################################################################
107122# HELPER FUNCTIONS
0 commit comments