Skip to content

Commit 3679d76

Browse files
Support no-aux top-k routing (noaux_tc) with EPLB redundant expert ranks
1 parent 6fa3410 commit 3679d76

File tree

8 files changed

+620
-23
lines changed

8 files changed

+620
-23
lines changed

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,19 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
647647
bool renormalize,
648648
float routed_scaling_factor);
649649

650+
// Forward declaration: no-aux top-k MoE routing with redundant expert
// ranks (EPLB). Defined in the noaux_tc_redundant custom-op source file.
std::vector<paddle::Tensor> NoauxTcRedundant(
    paddle::Tensor& scores,
    paddle::Tensor& scores_with_bias,
    paddle::Tensor& expert_id_to_ep_rank_array,
    paddle::Tensor& expert_in_rank_num_list,
    paddle::Tensor& tokens_per_expert_stats_list,
    int n_group,
    int topk_group,
    int topk,
    bool renormalize,
    float routed_scaling_factor,
    int redundant_ep_rank_num_plus_one);
662+
650663
#ifdef ENABLE_FP8
651664
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
652665
const paddle::Tensor& x,
@@ -1485,6 +1498,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
14851498

14861499
m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
14871500

1501+
m.def("noaux_tc_redunant",
1502+
&NoauxTcRedundant,
1503+
"noaux_tc_redundant for MoE compute");
1504+
14881505
#ifdef ENABLE_FP8
14891506
m.def("cutlass_fp8_fp8_half_gemm_fused",
14901507
&cutlass_fp8_fp8_half_gemm_func,
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <algorithm>
18+
#include <optional>
19+
20+
#include "helper.h"
21+
#include "noauxtc_kernel.h"
22+
23+
// No-aux top-k MoE routing with redundant expert ranks (EPLB).
//
// Mutates `scores` in place (returned as the op's "output_tensor") and
// updates `tokens_per_expert_stats_list` (declared inplace in the op
// registration). `group_scores` is kernel workspace; `topk_values` /
// `topk_indices` are freshly allocated outputs.
//
// NOTE(review): the kernel is instantiated for <float, int64_t> and the
// tensors are read via data<float>() / data<int>() regardless of
// `input_type` — presumably callers always pass float32 scores and int32
// EP-rank tables; confirm against the python wrapper.
std::vector<paddle::Tensor> NoauxTcRedundant(
    paddle::Tensor& scores,
    paddle::Tensor& scores_with_bias,
    paddle::Tensor& expert_id_to_ep_rank_array,
    paddle::Tensor& expert_in_rank_num_list,
    paddle::Tensor& tokens_per_expert_stats_list,
    int n_group,
    int topk_group,
    int topk,
    bool renormalize,
    float routed_scaling_factor,
    int redundant_ep_rank_num_plus_one) {
  auto input_shape = scores_with_bias.shape();
  PD_CHECK(input_shape.size() == 2);
  int64_t num_tokens = input_shape[0];
  int64_t num_experts = input_shape[1];
  auto input_type = scores_with_bias.dtype();
  auto place = scores_with_bias.place();
  auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
  auto topk_values = paddle::empty({num_tokens, topk}, input_type, place);
  auto topk_indices =
      paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place);
  auto stream = scores_with_bias.stream();

  // data<T>() already returns T*, so the reinterpret_cast wrappers in the
  // original were redundant and have been removed.
  invokeNoAuxTcRedundant<float, int64_t>(
      scores.data<float>(),
      group_scores.data<float>(),
      topk_values.data<float>(),
      topk_indices.data<int64_t>(),
      scores_with_bias.data<float>(),
      expert_id_to_ep_rank_array.data<int>(),
      expert_in_rank_num_list.data<int>(),
      tokens_per_expert_stats_list.data<int>(),
      num_tokens,
      num_experts,
      n_group,
      topk_group,
      topk,
      renormalize,
      routed_scaling_factor,
      redundant_ep_rank_num_plus_one,
      stream);

  return {scores, topk_values, topk_indices};
}
68+
69+
std::vector<paddle::DataType> NoauxTcRedundantInferDtype(
70+
const paddle::DataType& scores_dtype,
71+
const paddle::DataType& scores_with_bias_dtype) {
72+
return {scores_dtype, scores_dtype, paddle::DataType::INT64};
73+
}
74+
75+
// Output shapes for (output_tensor, topk_values, topk_indices):
// scores shape is passed through unchanged; both top-k outputs are
// [num_tokens, topk].
std::vector<std::vector<int64_t>> NoauxTcRedundantInferShape(
    const std::vector<int64_t>& scores_shape,
    const std::vector<int64_t>&,
    const int topk) {
  const int64_t token_count = scores_shape[0];
  const std::vector<int64_t> topk_out_shape{token_count, topk};
  return {scores_shape, topk_out_shape, topk_out_shape};
}
84+
85+
// Static-graph registration for the redundant-expert no-aux routing op.
// tokens_per_expert_stats_list is updated in place and re-exposed as
// tokens_per_expert_stats_list_out via SetInplaceMap.
// NOTE(review): attr spacing is inconsistent ("topk:int" vs "topk_group: int");
// kept byte-identical here since Paddle parses these strings at runtime —
// confirm the parser tolerates both before normalizing.
PD_BUILD_STATIC_OP(noaux_tc_redundant)
    .Inputs({"scores",
             "scores_with_bias",
             "expert_id_to_ep_rank_array",
             "expert_in_rank_num_list",
             "tokens_per_expert_stats_list"})
    .Outputs({"output_tensor",
              "topk_values",
              "topk_indices",
              "tokens_per_expert_stats_list_out"})
    .Attrs({"n_group: int",
            "topk_group: int",
            "topk:int",
            "renormalize: bool",
            "routed_scaling_factor: float",
            "redundant_ep_rank_num_plus_one:int"})
    .SetInplaceMap({{"tokens_per_expert_stats_list",
                     "tokens_per_expert_stats_list_out"}})
    .SetKernelFn(PD_KERNEL(NoauxTcRedundant))
    .SetInferShapeFn(PD_INFER_SHAPE(NoauxTcRedundantInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcRedundantInferDtype));

0 commit comments

Comments
 (0)