Skip to content

Commit efa5971

Browse files
authored
less time (PaddlePaddle#70442)
1 parent b087e48 commit efa5971

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

paddle/phi/infermeta/multiary.cc

+8-11
Original file line numberDiff line numberDiff line change
@@ -4734,18 +4734,13 @@ void RmsNormInferMeta(const MetaTensor& x,
47344734
MetaTensor* out,
47354735
MetaTensor* residual_out,
47364736
MetaTensor* inv_var) {
4737-
std::vector<int64_t> x_dims_vec = common::vectorize(x.dims());
4738-
auto x_dims_size = x_dims_vec.size();
4737+
size_t x_dims_size = x.dims().size();
47394738

47404739
size_t normalized_dims = 1;
47414740
for (size_t i = begin_norm_axis; i < x_dims_size; ++i) {
4742-
normalized_dims *= x_dims_vec[i];
4741+
normalized_dims *= x.dims().at(i);
47434742
}
47444743

4745-
std::vector<int64_t> inv_var_dims;
4746-
for (size_t i = size_t(0); i < static_cast<size_t>(begin_norm_axis); i++) {
4747-
inv_var_dims.push_back(x_dims_vec[i]);
4748-
}
47494744
PADDLE_ENFORCE_EQ(normalized_dims,
47504745
norm_weight.dims()[0],
47514746
common::errors::InvalidArgument(
@@ -4756,9 +4751,7 @@ void RmsNormInferMeta(const MetaTensor& x,
47564751
normalized_dims,
47574752
norm_weight.dims()[0]));
47584753

4759-
auto out_dims = common::make_ddim(x_dims_vec);
4760-
4761-
out->set_dims(out_dims);
4754+
out->set_dims(x.dims());
47624755

47634756
if (quant_scale > 0) {
47644757
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
@@ -4774,12 +4767,16 @@ void RmsNormInferMeta(const MetaTensor& x,
47744767

47754768
if (inv_var != nullptr) {
47764769
inv_var->set_dtype(phi::DataType::FLOAT32);
4770+
std::vector<int64_t> inv_var_dims;
4771+
for (size_t i = size_t(0); i < static_cast<size_t>(begin_norm_axis); i++) {
4772+
inv_var_dims.push_back(x.dims().at(i));
4773+
}
47774774
inv_var->set_dims(common::make_ddim(inv_var_dims));
47784775
inv_var->set_layout(x.layout());
47794776
}
47804777

47814778
if (residual != nullptr) {
4782-
residual_out->set_dims(out_dims);
4779+
residual_out->set_dims(x.dims());
47834780
residual_out->set_dtype(x.dtype());
47844781
residual_out->set_layout(x.layout());
47854782
residual_out->share_lod(x);

0 commit comments

Comments
 (0)