Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
feb65b7
[PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization
xiaoxi-wangfj Jul 3, 2025
a7de66c
[PyTorch/Common] Fuse permute+pad and unpermute+unpad support with_me…
xiaoxi-wangfj Dec 11, 2025
f550684
[PyTorch]format code
xiaoxi-wangfj Dec 11, 2025
6069277
[Common]perf expert_idx loaded once
xiaoxi-wangfj Dec 11, 2025
053abee
Merge branch 'main' into fused_perm_pad
xiaoxi-wangfj Dec 12, 2025
1ea08f7
fix: pad_offsets can be None
xiaoxi-wangfj Dec 17, 2025
ac12a91
Merge branch 'main' into fused_perm_pad
xiaoxi-wangfj Dec 17, 2025
230939c
add padding + merging probs bwd support. Not tested
tdophung Dec 11, 2025
f301462
Fix garbage initialized act grad
tdophung Dec 11, 2025
7ed584c
all test passing for jax permutation + pad
tdophung Dec 17, 2025
7998ce8
change tokens_per_experts APIs to num_out_tokens with conservative a…
tdophung Dec 17, 2025
dd5c72a
change test permutation to reduce test time
tdophung Dec 19, 2025
ce187b6
triggering PR refresh
tdophung Dec 19, 2025
7dc9ccb
format code
tdophung Dec 20, 2025
1fbe99c
Remove some tests cases from pytorch side. Add a separate toekn_dispa…
tdophung Dec 20, 2025
592f675
format code
tdophung Dec 20, 2025
1d43279
remove chance for inefficiency in moving between CPU and GPU, remove …
tdophung Dec 20, 2025
4169a4e
fix lint in jax
tdophung Dec 22, 2025
c619adf
account for both jax newer and older than version 0.8.2. Adjusted gpu…
tdophung Dec 22, 2025
405b341
format code
tdophung Dec 22, 2025
7cad5c5
fix typo
tdophung Dec 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Loading