@@ -4734,18 +4734,13 @@ void RmsNormInferMeta(const MetaTensor& x,
4734
4734
MetaTensor* out,
4735
4735
MetaTensor* residual_out,
4736
4736
MetaTensor* inv_var) {
4737
- std::vector<int64_t > x_dims_vec = common::vectorize (x.dims ());
4738
- auto x_dims_size = x_dims_vec.size ();
4737
+ size_t x_dims_size = x.dims ().size ();
4739
4738
4740
4739
size_t normalized_dims = 1 ;
4741
4740
for (size_t i = begin_norm_axis; i < x_dims_size; ++i) {
4742
- normalized_dims *= x_dims_vec[i] ;
4741
+ normalized_dims *= x. dims (). at (i) ;
4743
4742
}
4744
4743
4745
- std::vector<int64_t > inv_var_dims;
4746
- for (size_t i = size_t (0 ); i < static_cast <size_t >(begin_norm_axis); i++) {
4747
- inv_var_dims.push_back (x_dims_vec[i]);
4748
- }
4749
4744
PADDLE_ENFORCE_EQ (normalized_dims,
4750
4745
norm_weight.dims ()[0 ],
4751
4746
common::errors::InvalidArgument (
@@ -4756,9 +4751,7 @@ void RmsNormInferMeta(const MetaTensor& x,
4756
4751
normalized_dims,
4757
4752
norm_weight.dims ()[0 ]));
4758
4753
4759
- auto out_dims = common::make_ddim (x_dims_vec);
4760
-
4761
- out->set_dims (out_dims);
4754
+ out->set_dims (x.dims ());
4762
4755
4763
4756
if (quant_scale > 0 ) {
4764
4757
if (fabs (quant_max_bound - 127 .0f ) < 0.000001 ) {
@@ -4774,12 +4767,16 @@ void RmsNormInferMeta(const MetaTensor& x,
4774
4767
4775
4768
if (inv_var != nullptr ) {
4776
4769
inv_var->set_dtype (phi::DataType::FLOAT32);
4770
+ std::vector<int64_t > inv_var_dims;
4771
+ for (size_t i = size_t (0 ); i < static_cast <size_t >(begin_norm_axis); i++) {
4772
+ inv_var_dims.push_back (x.dims ().at (i));
4773
+ }
4777
4774
inv_var->set_dims (common::make_ddim (inv_var_dims));
4778
4775
inv_var->set_layout (x.layout ());
4779
4776
}
4780
4777
4781
4778
if (residual != nullptr ) {
4782
- residual_out->set_dims (out_dims );
4779
+ residual_out->set_dims (x. dims () );
4783
4780
residual_out->set_dtype (x.dtype ());
4784
4781
residual_out->set_layout (x.layout ());
4785
4782
residual_out->share_lod (x);
0 commit comments