Skip to content

Commit 1048c1b

Browse files
committed
issue/421: 适配 rmsnorm 测例修改,支持 bf16 和 f16数据类型 weights
1 parent c0d1b0d commit 1048c1b

File tree

4 files changed

+11
-7
lines changed

4 files changed

+11
-7
lines changed

src/infiniop/ops/causal_softmax/kunlun/kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ __device__ void causalSoftmaxBlock(
5454
// Apply softmax
5555
for (size_t col = core_id(); col < width; col += BLOCK_SIZE) {
5656
if (sum_ != 0) {
57-
y[col] = to<Tdata>(to<Tcompute>(y[col]) / sum_);
57+
y[col] = Tdata(Tcompute(y[col]) / sum_);
5858
} else {
5959
y[col] = Tdata(0);
6060
}

src/infiniop/ops/rms_norm/kunlun/kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ __device__ void rmsnormBlock(
2727
for (size_t i = core_id(); i < dim; i += BLOCK_SIZE) {
2828
Tdata xi = x[i];
2929
Tweight wi = w[i];
30-
y[i] = static_cast<Tdata>(to<Tcompute>(xi) * to<Tcompute>(wi) * rms);
30+
y[i] = Tdata(Tcompute(xi) * Tcompute(wi) * rms);
3131
}
3232
sync_cluster();
3333
}

src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.xpu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,14 @@ infiniStatus_t launchKernel(
9595

9696
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
9797
LAUNCH_KERNEL(half, half, float);
98+
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
99+
LAUNCH_KERNEL(half, bfloat16_t, float);
98100
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
99101
LAUNCH_KERNEL(half, float, float);
100102
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
101103
LAUNCH_KERNEL(bfloat16_t, bfloat16_t, float);
104+
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
105+
LAUNCH_KERNEL(bfloat16_t, half, float);
102106
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
103107
LAUNCH_KERNEL(bfloat16_t, float, float);
104108
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {

src/infiniop/reduce/kunlun/reduce_kunlun.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ __device__ inline Tcompute sumSquared(__shared_ptr__ const Tdata *data_ptr, size
1414

1515
for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
1616
Tdata xi = data_ptr[i];
17-
ss += to<Tcompute>(xi) * to<Tcompute>(xi);
17+
ss += Tcompute(xi) * Tcompute(xi);
1818
}
1919

2020
__shared__ Tcompute temp_storage;
2121
if (core_id() == 0) {
22-
temp_storage = to<Tcompute>(0.f);
22+
temp_storage = Tcompute(0.f);
2323
}
2424
sync_cluster();
2525

@@ -36,12 +36,12 @@ __device__ inline Tcompute sum(__shared_ptr__ const Tdata *data_ptr, size_t coun
3636

3737
for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
3838
Tdata xi = data_ptr[i];
39-
ss += to<Tcompute>(xi);
39+
ss += Tcompute(xi);
4040
}
4141

4242
__shared__ Tcompute temp_storage;
4343
if (core_id() == 0) {
44-
temp_storage = to<Tcompute>(0.f);
44+
temp_storage = Tcompute(0.f);
4545
}
4646
sync_cluster();
4747

@@ -58,7 +58,7 @@ __device__ inline Tdata max(__shared_ptr__ const Tdata *data_ptr, size_t count)
5858

5959
for (size_t i = core_id(); i < count; i += BLOCK_SIZE) {
6060
Tdata xi = data_ptr[i];
61-
max_val = fmax(max_val, to<Tdata>(xi));
61+
max_val = fmax(max_val, Tdata(xi));
6262
}
6363

6464
__shared__ Tdata temp_storage;

0 commit comments

Comments
 (0)