[PerfXLab] optimize fused_moe #2230

Open
v2yield wants to merge 2 commits into flagos-ai:master from v2yield:fused_moe_opt

Conversation

Contributor

@v2yield v2yield commented Apr 3, 2026

PR Category

[Operator]

Type of Change

[Performance Optimization]

Description

Optimize fused_moe_impl on NVIDIA Hopper.
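For readers unfamiliar with the tensor layout in the Size Detail columns below, here is a minimal unfused reference of the MoE computation being benchmarked, written as a sketch in numpy. All names (`moe_forward_reference`, `w13`, `w2`, etc.) are hypothetical, and the SwiGLU activation (SiLU(gate) * up, with gate and up packed together in the first weight) is an assumption based on the typical vLLM-style fused MoE; the actual kernel in this PR may differ.

```python
import numpy as np

def moe_forward_reference(x, w13, w2, topk_ids, topk_weights):
    """Unfused sketch of a fused-MoE forward pass (hypothetical names).

    Shapes, matching the Size Detail columns of the tables below:
      x:            (num_tokens, hidden)                  e.g. (1, 4096)
      w13:          (num_experts, 2*inter, hidden)        e.g. (8, 28672, 4096)
      w2:           (num_experts, hidden, inter)          e.g. (8, 4096, 14336)
      topk_ids:     (num_tokens, topk)                    e.g. (1, 2)
      topk_weights: (num_tokens, topk)                    e.g. (1, 2)
    """
    inter = w2.shape[2]
    out = np.zeros_like(x)
    for t in range(x.shape[0]):              # each token
        for k in range(topk_ids.shape[1]):   # each selected expert
            e = int(topk_ids[t, k])
            gate_up = w13[e] @ x[t]                     # (2*inter,)
            gate, up = gate_up[:inter], gate_up[inter:]
            act = gate / (1.0 + np.exp(-gate)) * up     # SiLU(gate) * up
            out[t] += topk_weights[t, k] * (w2[e] @ act)
    return out
```

A fused kernel computes the same result but avoids materializing per-expert intermediates and launches far fewer kernels; the tables below compare such a fused implementation against the vLLM baseline.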

Performance

Benchmarked on an H100 GPU.

Operator: fused_moe_gems_vs_vllm  Performance Test (dtype=torch.bfloat16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.275072            0.281024               0.979          [torch.Size([1, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([1, 2]), torch.Size([1, 2])]
SUCCESS               0.958016            0.952736               1.006          [torch.Size([4, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([4, 2]), torch.Size([4, 2])]
SUCCESS               0.963904            0.961296               1.003          [torch.Size([16, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([16, 2]), torch.Size([16, 2])]
SUCCESS               0.961808            0.976432               0.985          [torch.Size([64, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([64, 2]), torch.Size([64, 2])]
SUCCESS               1.007296            1.008480               0.999          [torch.Size([128, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([128, 2]), torch.Size([128, 2])]
SUCCESS               1.282592            1.055136               1.216          [torch.Size([256, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([256, 2]), torch.Size([256, 2])]
SUCCESS               2.106400            1.381088               1.525          [torch.Size([512, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([512, 2]), torch.Size([512, 2])]
SUCCESS               0.280288            0.262176               1.069          [torch.Size([1, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([1, 8]), torch.Size([1, 8])]
SUCCESS               0.938016            0.930560               1.008          [torch.Size([4, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([4, 8]), torch.Size([4, 8])]
SUCCESS               2.866816            2.872032               0.998          [torch.Size([16, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([16, 8]), torch.Size([16, 8])]
SUCCESS               6.248544            6.253120               0.999          [torch.Size([64, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([64, 8]), torch.Size([64, 8])]
SUCCESS               7.165920            7.346976               0.975          [torch.Size([128, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([128, 8]), torch.Size([128, 8])]
SUCCESS               7.440224            7.504896               0.991          [torch.Size([256, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([256, 8]), torch.Size([256, 8])]


Operator: fused_moe_gems_vs_vllm  Performance Test (dtype=torch.float16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.275264            0.282432               0.975          [torch.Size([1, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([1, 2]), torch.Size([1, 2])]
SUCCESS               0.712064            0.737728               0.965          [torch.Size([4, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([4, 2]), torch.Size([4, 2])]
SUCCESS               0.963104            0.961392               1.002          [torch.Size([16, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([16, 2]), torch.Size([16, 2])]
SUCCESS               0.962496            0.976384               0.986          [torch.Size([64, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([64, 2]), torch.Size([64, 2])]
SUCCESS               1.007136            1.006176               1.001          [torch.Size([128, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([128, 2]), torch.Size([128, 2])]
SUCCESS               1.201312            1.058528               1.135          [torch.Size([256, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([256, 2]), torch.Size([256, 2])]
SUCCESS               2.162480            1.312544               1.648          [torch.Size([512, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([512, 2]), torch.Size([512, 2])]
SUCCESS               0.280384            0.261360               1.073          [torch.Size([1, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([1, 8]), torch.Size([1, 8])]
SUCCESS               0.936752            0.931104               1.006          [torch.Size([4, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([4, 8]), torch.Size([4, 8])]
SUCCESS               2.750000            2.762672               0.995          [torch.Size([16, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([16, 8]), torch.Size([16, 8])]
SUCCESS               6.354528            6.368480               0.998          [torch.Size([64, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([64, 8]), torch.Size([64, 8])]
SUCCESS               7.142176            7.323616               0.975          [torch.Size([128, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([128, 8]), torch.Size([128, 8])]
SUCCESS               7.438688            7.500704               0.992          [torch.Size([256, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([256, 8]), torch.Size([256, 8])]

PASSED
Operator: fused_moe_fp8_gems_vs_vllm  Performance Test (dtype=torch.bfloat16, mode=kernel,level=comprehensive)                                                                                                                                                                    
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail                                                                                                                                                                                    
-----------------------------------------------------------------------------------------------                                                                                                                                                                                   
SUCCESS               0.196384            0.237920               0.825          [torch.Size([1, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([8]), torch.Size([8])]                                     
SUCCESS               0.371072            0.392128               0.946          [torch.Size([4, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([4, 2]), torch.Size([4, 2]), torch.Size([8]), torch.Size([8])]                                     
SUCCESS               0.538976            0.531168               1.015          [torch.Size([16, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([16, 2]), torch.Size([16, 2]), torch.Size([8]), torch.Size([8])]                                  
SUCCESS               0.569040            0.557904               1.020          [torch.Size([64, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([64, 2]), torch.Size([64, 2]), torch.Size([8]), torch.Size([8])]                                  
SUCCESS               0.560016            0.569504               0.983          [torch.Size([128, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([128, 2]), torch.Size([128, 2]), torch.Size([8]), torch.Size([8])]
SUCCESS               0.722112            0.609104               1.186          [torch.Size([256, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([256, 2]), torch.Size([256, 2]), torch.Size([8]), torch.Size([8])]
SUCCESS               1.153328            0.732192               1.575          [torch.Size([512, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([512, 2]), torch.Size([512, 2]), torch.Size([8]), torch.Size([8])]
SUCCESS               0.194480            0.274400               0.709          [torch.Size([1, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([1, 8]), torch.Size([1, 8]), torch.Size([256]), torch.Size([256])]
SUCCESS               0.531888            0.523808               1.015          [torch.Size([4, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([4, 8]), torch.Size([4, 8]), torch.Size([256]), torch.Size([256])]
SUCCESS               1.509856            1.518080               0.995          [torch.Size([16, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([16, 8]), torch.Size([16, 8]), torch.Size([256]), torch.Size([256])]
SUCCESS               3.233792            3.245568               0.996          [torch.Size([64, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([64, 8]), torch.Size([64, 8]), torch.Size([256]), torch.Size([256])]
SUCCESS               3.617680            3.626288               0.998          [torch.Size([128, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([128, 8]), torch.Size([128, 8]), torch.Size([256]), torch.Size([256])]
SUCCESS               3.787456            3.797728               0.997          [torch.Size([256, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([256, 8]), torch.Size([256, 8]), torch.Size([256]), torch.Size([256])]

PASSED
Operator: fused_moe_int8_gems_vs_vllm  Performance Test (dtype=torch.bfloat16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.228416            0.274240               0.833          [torch.Size([1, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               0.393392            0.398656               0.987          [torch.Size([4, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([4, 2]), torch.Size([4, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               0.512512            0.517584               0.990          [torch.Size([16, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([16, 2]), torch.Size([16, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               0.525792            0.533360               0.986          [torch.Size([64, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([64, 2]), torch.Size([64, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               0.745152            0.751760               0.991          [torch.Size([128, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([128, 2]), torch.Size([128, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               1.037760            1.044800               0.993          [torch.Size([256, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([256, 2]), torch.Size([256, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               1.539344            1.552144               0.992          [torch.Size([512, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([512, 2]), torch.Size([512, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               0.241312            0.276384               0.873          [torch.Size([1, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([1, 8]), torch.Size([1, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS               0.515472            0.508800               1.013          [torch.Size([4, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([4, 8]), torch.Size([4, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS               1.565984            1.568432               0.998          [torch.Size([16, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([16, 8]), torch.Size([16, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS               3.145424            3.149696               0.999          [torch.Size([64, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([64, 8]), torch.Size([64, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS               5.162640            5.164352               1.000          [torch.Size([128, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([128, 8]), torch.Size([128, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS               5.148848            5.153584               0.999          [torch.Size([256, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([256, 8]), torch.Size([256, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]

PASSED
Operator: fused_moe_int8_w8a16_gems_vs_bf16_deq  Performance Test (dtype=torch.bfloat16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               8.606912            8.603200               1.000          [torch.Size([1, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               8.943488            8.940992               1.000          [torch.Size([4, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([4, 2]), torch.Size([4, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               9.169584            9.162432               1.001          [torch.Size([16, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([16, 2]), torch.Size([16, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               9.304640            9.299504               1.001          [torch.Size([64, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([64, 2]), torch.Size([64, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               9.341920            9.332944               1.001          [torch.Size([128, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([128, 2]), torch.Size([128, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               9.386560            9.383968               1.000          [torch.Size([256, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([256, 2]), torch.Size([256, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               9.670736            9.661936               1.001          [torch.Size([512, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([512, 2]), torch.Size([512, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS              66.781761           66.778816               1.000          [torch.Size([1, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([1, 8]), torch.Size([1, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS              67.461441           67.433693               1.000          [torch.Size([4, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([4, 8]), torch.Size([4, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS              69.452736           69.448929               1.000          [torch.Size([16, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([16, 8]), torch.Size([16, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS              72.788803           72.798431               1.000          [torch.Size([64, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([64, 8]), torch.Size([64, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS              73.949059           73.943619               1.000          [torch.Size([128, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([128, 8]), torch.Size([128, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS              74.016228           74.004128               1.000          [torch.Size([256, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([256, 8]), torch.Size([256, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]

PASSED
Operator: fused_moe_int4_w4a16_gems_vs_bf16_deq  Performance Test (dtype=torch.bfloat16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               8.608480            8.606752               1.000          [torch.Size([1, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([1, 2]), torch.Size([1, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               8.944224            8.941472               1.000          [torch.Size([4, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([4, 2]), torch.Size([4, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               9.287616            9.285184               1.000          [torch.Size([16, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([16, 2]), torch.Size([16, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               9.305184            9.300704               1.000          [torch.Size([64, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([64, 2]), torch.Size([64, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               9.338000            9.329632               1.001          [torch.Size([128, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([128, 2]), torch.Size([128, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               9.383952            9.383088               1.000          [torch.Size([256, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([256, 2]), torch.Size([256, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS               9.661200            9.654784               1.001          [torch.Size([512, 4096]), torch.Size([8, 28672, 4096]), torch.Size([8, 4096, 14336]), torch.Size([512, 2]), torch.Size([512, 2]), torch.Size([8, 28672]), torch.Size([8, 4096])]
SUCCESS              66.784386           66.785759               1.000          [torch.Size([1, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([1, 8]), torch.Size([1, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS              67.446465           67.477730               1.000          [torch.Size([4, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([4, 8]), torch.Size([4, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS              69.381760           69.404419               1.000          [torch.Size([16, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([16, 8]), torch.Size([16, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS              72.859520           72.831459               1.000          [torch.Size([64, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([64, 8]), torch.Size([64, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS              73.862495           73.859329               1.000          [torch.Size([128, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([128, 8]), torch.Size([128, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]
SUCCESS              74.029663           74.111168               0.999          [torch.Size([256, 7168]), torch.Size([256, 4096, 7168]), torch.Size([256, 7168, 2048]), torch.Size([256, 8]), torch.Size([256, 8]), torch.Size([256, 4096]), torch.Size([256, 7168])]

PASSED

Contributor

@tengqm tengqm left a comment


Please touch the test cases related to this change.
The CI job is triggered by test-case changes; without them we cannot see how the change behaves.

When adapting the test cases, please strive to maximize test coverage.

We have encountered some compiler issues related to these fused operators.
