55包含以下内容:
66
77- [X] softmax_f32_kernel (grid level memory fence)
8- - [X] softmax_f32x4_kernel(grid level memory fence, float4向量化版本 )
8+ - [X] softmax_f32x4_kernel(grid level memory fence)
99- [X] softmax_f32_per_token_kernel(per token)
10- - [X] softmax_f32x4_per_token_kernel(per token, float4向量化版本 )
10+ - [X] softmax_f32x4_per_token_kernel(per token)
1111- [X] safe_softmax_f32_per_token_kernel(per token)
12- - [X] safe_softmax_f32x4_per_token_kernel(per token, float4向量化版本)
12+ - [X] safe_softmax_f32x4_per_token_kernel(per token)
13+ - [X] safe_softmax_f16_f32_per_token_kernel(per token)
14+ - [X] safe_softmax_f16x2_f32_per_token_kernel(per token)
15+ - [X] safe_softmax_f16x8_pack_f32_per_token_kernel(per token)
1316- [X] PyTorch bindings
1417
1518
@@ -24,25 +27,84 @@ python3 softmax.py
2427输出:
2528
2629``` bash
27- --------------------------------------------------------------------------------
28- out_f32: [1.909e-05, 0.00023536, 0.00010881], time:0.01697016ms
29- out_f32x4: [1.909e-05, 0.00023536, 0.00010881], time:0.01716042ms
30- out_f32_th: [1.909e-05, 0.00023536, 0.00010881], time:0.00715089ms
31- --------------------------------------------------------------------------------
32- out_f32(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.01011539ms
33- out_f32x4(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.01006842ms
34- out_f32_th(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.00547409ms
35- --------------------------------------------------------------------------------
36- out_f32(per): [0.00569158, 0.00022239, 0.00137839], time:0.01047754ms
37- out_f32x4(per): [0.00569158, 0.00022239, 0.00137839], time:0.01045704ms
38- out_f32(safe): [0.00569158, 0.00022239, 0.00137839], time:0.01054454ms
39- out_f32x4(safe): [0.00569158, 0.00022239, 0.00137839], time:0.01042986ms
40- out_f32_th(per): [0.00569158, 0.00022239, 0.00137839], time:0.00741696ms
41- --------------------------------------------------------------------------------
42- out_f32(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00419974ms
43- out_f32x4(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00316834ms
44- out_f32(safe v2): [0.00569158, 0.00022239, 0.00137839], time:0.00603890ms
45- out_f32x4(safe v2): [0.00569158, 0.00022239, 0.00137839], time:0.00319862ms
46- out_f32_th(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00577068ms
47- --------------------------------------------------------------------------------
48- ```
30+ ----------------------------------------------------------------------------------------------------
31+ N=16384
32+ ----------------------------------------------------------------------------------------------------
33+ out_f32(fence): [' 5.912e-05 ' , ' 9.61e-05 ' , ' 4.271e-05 ' ], time:0.01040053ms
34+ out_f32x4(fence): [' 5.912e-05 ' , ' 9.61e-05 ' , ' 4.271e-05 ' ], time:0.01053643ms
35+ out_f32_th: [' 5.912e-05 ' , ' 9.61e-05 ' , ' 4.271e-05 ' ], time:0.00582504ms
36+ ----------------------------------------------------------------------------------------------------
37+ S=4096, H=256
38+ ----------------------------------------------------------------------------------------------------
39+ out_f32(per): [' 0.0015298 ' , ' 0.00619088 ' , ' 0.00529766 ' ], time:0.00627208ms
40+ out_f32x4(per): [' 0.0015298 ' , ' 0.00619088 ' , ' 0.00529766 ' ], time:0.00394082ms
41+ out_f32(safe): [' 0.0015298 ' , ' 0.00619088 ' , ' 0.00529766 ' ], time:0.00941491ms
42+ out_f32x4(safe): [' 0.0015298 ' , ' 0.00619088 ' , ' 0.00529766 ' ], time:0.00413442ms
43+ out_f32_th(per): [' 0.0015298 ' , ' 0.00619088 ' , ' 0.00529766 ' ], time:0.00602674ms
44+ ----------------------------------------------------------------------------------------------------
45+ out_f16f32(safe): [' 0.00152969 ' , ' 0.00619125 ' , ' 0.00529861 ' ], time:0.00912046ms
46+ out_f16x2f32(safe): [' 0.00152969 ' , ' 0.00619125 ' , ' 0.00529861 ' ], time:0.00522232ms
47+ out_f16x8packf32(safe): [' 0.00152969 ' , ' 0.00619125 ' , ' 0.00529861 ' ], time:0.00413895ms
48+ out_f16_th(per): [' 0.00152969 ' , ' 0.00619125 ' , ' 0.00529861 ' ], time:0.00605321ms
49+ ----------------------------------------------------------------------------------------------------
50+ ----------------------------------------------------------------------------------------------------
51+ S=4096, H=512
52+ ----------------------------------------------------------------------------------------------------
53+ out_f32(per): [' 0.00200376 ' , ' 0.00063461 ' , ' 0.00163568 ' ], time:0.01139641ms
54+ out_f32x4(per): [' 0.00200376 ' , ' 0.00063461 ' , ' 0.00163568 ' ], time:0.00515914ms
55+ out_f32(safe): [' 0.00200376 ' , ' 0.00063461 ' , ' 0.00163568 ' ], time:0.01834297ms
56+ out_f32x4(safe): [' 0.00200376 ' , ' 0.00063461 ' , ' 0.00163568 ' ], time:0.00574923ms
57+ out_f32_th(per): [' 0.00200376 ' , ' 0.00063461 ' , ' 0.00163568 ' ], time:0.00657558ms
58+ ----------------------------------------------------------------------------------------------------
59+ out_f16f32(safe): [' 0.00200462 ' , ' 0.00063467 ' , ' 0.00163555 ' ], time:0.01782560ms
60+ out_f16x2f32(safe): [' 0.00200462 ' , ' 0.00063467 ' , ' 0.00163555 ' ], time:0.00919509ms
61+ out_f16x8packf32(safe): [' 0.00200462 ' , ' 0.00063467 ' , ' 0.00163555 ' ], time:0.00415683ms
62+ out_f16_th(per): [' 0.00200462 ' , ' 0.00063467 ' , ' 0.00163555 ' ], time:0.00634599ms
63+ ----------------------------------------------------------------------------------------------------
64+ ----------------------------------------------------------------------------------------------------
65+ S=4096, H=1024
66+ ----------------------------------------------------------------------------------------------------
67+ out_f32(per): [' 0.0009461 ' , ' 0.00073918 ' , ' 0.00074397 ' ], time:0.03191805ms
68+ out_f32x4(per): [' 0.0009461 ' , ' 0.00073918 ' , ' 0.00074397 ' ], time:0.00862813ms
69+ out_f32(safe): [' 0.0009461 ' , ' 0.00073918 ' , ' 0.00074397 ' ], time:0.04873967ms
70+ out_f32x4(safe): [' 0.0009461 ' , ' 0.00073918 ' , ' 0.00074397 ' ], time:0.01027441ms
71+ out_f32_th(per): [' 0.0009461 ' , ' 0.00073918 ' , ' 0.00074397 ' ], time:0.01181388ms
72+ ----------------------------------------------------------------------------------------------------
73+ out_f16f32(safe): [' 0.00094604 ' , ' 0.0007391 ' , ' 0.00074387 ' ], time:0.04671884ms
74+ out_f16x2f32(safe): [' 0.00094604 ' , ' 0.0007391 ' , ' 0.00074387 ' ], time:0.01810408ms
75+ out_f16x8packf32(safe): [' 0.00094604 ' , ' 0.0007391 ' , ' 0.00074387 ' ], time:0.00601912ms
76+ out_f16_th(per): [' 0.00094604 ' , ' 0.0007391 ' , ' 0.00074387 ' ], time:0.01047063ms
77+ ----------------------------------------------------------------------------------------------------
78+ ----------------------------------------------------------------------------------------------------
79+ S=4096, H=2048
80+ ----------------------------------------------------------------------------------------------------
81+ out_f32x4(per): [' 9.216e-05 ' , ' 0.00045569 ' , ' 0.00013162 ' ], time:0.01605988ms
82+ out_f32x4(safe): [' 9.216e-05 ' , ' 0.00045569 ' , ' 0.00013162 ' ], time:0.02089310ms
83+ out_f32_th(per): [' 9.216e-05 ' , ' 0.00045569 ' , ' 0.00013162 ' ], time:0.06726241ms
84+ ----------------------------------------------------------------------------------------------------
85+ out_f16x2f32(safe): [' 9.215e-05 ' , ' 0.00045562 ' , ' 0.00013161 ' ], time:0.04824972ms
86+ out_f16x8packf32(safe): [' 9.215e-05 ' , ' 0.00045562 ' , ' 0.00013161 ' ], time:0.01086283ms
87+ out_f16_th(per): [' 9.215e-05 ' , ' 0.00045562 ' , ' 0.00013161 ' ], time:0.07232165ms
88+ ----------------------------------------------------------------------------------------------------
89+ ----------------------------------------------------------------------------------------------------
90+ S=4096, H=4096
91+ ----------------------------------------------------------------------------------------------------
92+ out_f32x4(per): [' 0.00017665 ' , ' 0.00035685 ' , ' 0.00017236 ' ], time:0.18465948ms
93+ out_f32x4(safe): [' 0.00017665 ' , ' 0.00035685 ' , ' 0.00017236 ' ], time:0.18565655ms
94+ out_f32_th(per): [' 0.00017665 ' , ' 0.00035685 ' , ' 0.00017236 ' ], time:0.18744922ms
95+ ----------------------------------------------------------------------------------------------------
96+ out_f16x8packf32(safe): [' 0.00017667 ' , ' 0.00035691 ' , ' 0.00017238 ' ], time:0.02254891ms
97+ out_f16_th(per): [' 0.00017667 ' , ' 0.00035691 ' , ' 0.00017238 ' ], time:0.08283138ms
98+ ----------------------------------------------------------------------------------------------------
99+ ----------------------------------------------------------------------------------------------------
100+ S=4096, H=8192
101+ ----------------------------------------------------------------------------------------------------
102+ out_f16x8packf32(safe): [' 4.166e-05 ' , ' 3.767e-05 ' , ' 1.562e-05 ' ], time:0.19313049ms
103+ out_f16_th(per): [' 4.166e-05 ' , ' 3.767e-05 ' , ' 1.562e-05 ' ], time:0.19356799ms
104+ ----------------------------------------------------------------------------------------------------
105+ S=8192, H=8192
106+ ----------------------------------------------------------------------------------------------------
107+ out_f16x8packf32(safe): [' 4.208e-05 ' , ' 0.00015438 ' , ' 7.409e-05 ' ], time:0.39828229ms
108+ out_f16_th(per): [' 4.208e-05 ' , ' 0.00015438 ' , ' 7.409e-05 ' ], time:0.40599036ms
109+ ----------------------------------------------------------------------------------------------------
110+ ```
0 commit comments