support fuse layer norm grad for npu #10614

Merged: 13 commits, Jan 23, 2025
49 changes: 33 additions & 16 deletions oneflow/core/autograd/gradient_funcs/layer_norm.cpp
@@ -18,6 +18,9 @@ limitations under the License.
#include "oneflow/core/functional/functional.h"

namespace oneflow {

DEFINE_ENV_BOOL(ONEFLOW_USE_FUSE_LAYER_NORM_GRAD, false);

namespace one {

struct LayerNormCaptureState : public AutoGradCaptureState {
@@ -107,22 +110,36 @@ Maybe<void> LayerNorm::Apply(const LayerNormCaptureState* ctx, const TensorTuple
std::shared_ptr<Tensor> mean = saved_tensors.at(ctx->mean_index);
std::shared_ptr<Tensor> inv_variance = saved_tensors.at(ctx->inv_variance_index);

if (ctx->has_affine) {
// Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance,
// Int64 begin_params_axis)
const auto& results =
JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis));
in_grads->at(1) = results->at(0); // For gamma.
in_grads->at(2) = results->at(1); // For beta.
}
if (ctx->x_requires_grad) {
if (ctx->scale) {
std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
in_grads->at(0) = JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma,
begin_norm_axis, ctx->epsilon));
} else {
in_grads->at(0) =
JUST(functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, ctx->epsilon));
if (EnvBool<ONEFLOW_USE_FUSE_LAYER_NORM_GRAD>()) {
// Fused backward path; currently exercised only by the NPU backend.
CHECK(ctx->has_affine) << "LayerNorm::Apply requires has_affine when ONEFLOW_USE_FUSE_LAYER_NORM_GRAD is enabled (NPU GPT2 test).";
if (ctx->x_requires_grad) {
if (ctx->scale) {
std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
*in_grads = *JUST(functional::FuseLayerNormGrad(
dy, x, mean, inv_variance, gamma, begin_norm_axis, begin_params_axis, ctx->epsilon));
} else {
UNIMPLEMENTED();
}
}
} else {
if (ctx->has_affine) {
// Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance,
// Int64 begin_params_axis)
const auto& results =
JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis));
in_grads->at(1) = results->at(0); // For gamma.
in_grads->at(2) = results->at(1); // For beta.
}
if (ctx->x_requires_grad) {
if (ctx->scale) {
std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
in_grads->at(0) = JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma,
begin_norm_axis, ctx->epsilon));
} else {
in_grads->at(0) = JUST(
functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, ctx->epsilon));
}
}
}
return Maybe<void>::Ok();
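
The fused branch above is gated by ONEFLOW_USE_FUSE_LAYER_NORM_GRAD, which defaults to false, so the existing LayerNormParamGrad/LayerNormGrad path stays the default on CUDA and CPU unless the variable is set. The following minimal standalone sketch shows how such a cached boolean environment gate behaves; it assumes the usual "parse once, default false" semantics and is not OneFlow's DEFINE_ENV_BOOL implementation.

#include <cstdlib>
#include <cstring>

// Illustrative cached env gate (not OneFlow's DEFINE_ENV_BOOL): the variable is
// read and parsed once; "1" or "true" enables the fused path, anything else
// keeps the default of false.
static bool UseFuseLayerNormGrad() {
  static const bool enabled = [] {
    const char* v = std::getenv("ONEFLOW_USE_FUSE_LAYER_NORM_GRAD");
    return v != nullptr && (std::strcmp(v, "1") == 0 || std::strcmp(v, "true") == 0);
  }();
  return enabled;
}

int main() { return UseFuseLayerNormGrad() ? 0 : 1; }
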
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -1558,6 +1558,10 @@
signature: "Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Double epsilon) => LayerNormAffineGrad"
bind_python: False

- name: "fuse_layer_norm_grad"
signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => FuseLayerNormGrad"
bind_python: False

- name: "layer_norm_param_grad"
signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_params_axis) => LayerNormParamGrad"
bind_python: False
31 changes: 31 additions & 0 deletions oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -983,6 +983,36 @@ class LayerNormAffineGradFunctor {
std::shared_ptr<OpExpr> op_;
};

class FuseLayerNormGradFunctor {
public:
FuseLayerNormGradFunctor() {
op_ = CHECK_JUST(one::OpBuilder("fuse_layer_norm_grad")
.Input("dy")
.Input("x")
.Input("mean")
.Input("inv_variance")
.Input("gamma")
.Output("dx")
.Output("gamma_diff")
.Output("beta_diff")
.Build());
}
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,
const std::shared_ptr<one::Tensor>& x,
const std::shared_ptr<one::Tensor>& mean,
const std::shared_ptr<one::Tensor>& inv_variance,
const std::shared_ptr<one::Tensor>& gamma,
const int64_t& begin_norm_axis, const int64_t& begin_params_axis,
const double& epsilon) const {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon");
attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon);
return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance, gamma}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class LayerNormParamGradFunctor {
public:
LayerNormParamGradFunctor() {
@@ -1707,6 +1737,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::LayerNormGradFunctor>("LayerNormGrad");
m.add_functor<impl::LayerNormAffineGradFunctor>("LayerNormAffineGrad");
m.add_functor<impl::LayerNormParamGradFunctor>("LayerNormParamGrad");
m.add_functor<impl::FuseLayerNormGradFunctor>("FuseLayerNormGrad");
m.add_functor<impl::GroupNormGradFunctor>("GroupNormGrad");
m.add_functor<impl::GroupNormParamGradFunctor>("GroupNormParamGrad");
m.add_functor<impl::BroadcastMatmulGradBFunctor>("BroadcastMatmulGradB");
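
For reference, the fused op is meant to produce in a single dispatch what LayerNormGrad and LayerNormParamGrad produce separately: dx, gamma_diff and beta_diff. The standalone sketch below spells out the standard layer-norm backward formulas for a 2-D [rows, cols] input with begin_norm_axis == begin_params_axis == 1; it is a plain-C++ illustration of the math under those assumptions, not the NPU kernel this PR targets.

#include <cstddef>
#include <vector>

// Reference math for the fused layer-norm backward on a row-major [rows, cols]
// input: x_hat = (x - mean) * inv_variance, dxhat = dy * gamma, and
// dx = inv_variance * (dxhat - mean_j(dxhat) - x_hat * mean_j(dxhat * x_hat)),
// while gamma_diff / beta_diff are column-wise reductions of dy * x_hat and dy.
void FuseLayerNormGradReference(const std::vector<float>& dy, const std::vector<float>& x,
                                const std::vector<float>& mean, const std::vector<float>& inv_variance,
                                const std::vector<float>& gamma, size_t rows, size_t cols,
                                std::vector<float>* dx, std::vector<float>* gamma_diff,
                                std::vector<float>* beta_diff) {
  dx->assign(rows * cols, 0.f);
  gamma_diff->assign(cols, 0.f);
  beta_diff->assign(cols, 0.f);
  const float n = static_cast<float>(cols);
  for (size_t i = 0; i < rows; ++i) {
    const float r = inv_variance[i];
    float sum_dxhat = 0.f;
    float sum_dxhat_xhat = 0.f;
    for (size_t j = 0; j < cols; ++j) {
      const float xhat = (x[i * cols + j] - mean[i]) * r;
      const float dxhat = dy[i * cols + j] * gamma[j];
      sum_dxhat += dxhat;
      sum_dxhat_xhat += dxhat * xhat;
      (*gamma_diff)[j] += dy[i * cols + j] * xhat;
      (*beta_diff)[j] += dy[i * cols + j];
    }
    for (size_t j = 0; j < cols; ++j) {
      const float xhat = (x[i * cols + j] - mean[i]) * r;
      const float dxhat = dy[i * cols + j] * gamma[j];
      (*dx)[i * cols + j] = r * (dxhat - sum_dxhat / n - xhat * sum_dxhat_xhat / n);
    }
  }
}
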
2 changes: 2 additions & 0 deletions oneflow/core/job_rewriter/auto_mixed_precision.cpp
@@ -359,6 +359,8 @@ REGISTER_NO_CAST_REGISTRY("layer_norm_grad", "mean", 0)
REGISTER_NO_CAST_REGISTRY("layer_norm_grad", "inv_variance", 0)
REGISTER_NO_CAST_REGISTRY("layer_norm_param_grad", "mean", 0)
REGISTER_NO_CAST_REGISTRY("layer_norm_param_grad", "inv_variance", 0)
REGISTER_NO_CAST_REGISTRY("fuse_layer_norm_grad", "mean", 0)
REGISTER_NO_CAST_REGISTRY("fuse_layer_norm_grad", "inv_variance", 0)

} // namespace
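
These entries mirror the existing layer_norm_grad / layer_norm_param_grad registrations: the fused op joins the AMP gray list (next file), so it may run in reduced precision, but its mean and inv_variance inputs are never cast because they carry the statistics dtype rather than the activation dtype. The sketch below states that rule explicitly; it assumes half-precision activations keep float32 statistics, which matches the usual reading of InferBnParamDataType but is not copied from OneFlow source.

// Hedged sketch of the statistics-dtype convention the no-cast entries rely on.
enum class DType { kFloat16, kBFloat16, kFloat, kDouble };

DType BnParamDType(DType x_dtype) {
  // Half-precision activations keep float32 mean / inv_variance; wider types keep their own dtype.
  return (x_dtype == DType::kFloat16 || x_dtype == DType::kBFloat16) ? DType::kFloat : x_dtype;
}

int main() { return BnParamDType(DType::kFloat16) == DType::kFloat ? 0 : 1; }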

1 change: 1 addition & 0 deletions oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp
@@ -95,6 +95,7 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
"layer_norm",
"layer_norm_param_grad",
"layer_norm_grad",
"fuse_layer_norm_grad",
"skip_layer_norm",
"rms_norm",
"rms_norm_grad",
29 changes: 29 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -7071,6 +7071,35 @@ def OneFlow_LayerNormGradOp : OneFlow_BaseOp<"layer_norm_grad", [NoMemoryEffect,
let has_data_type_infer_fn = 1;
}

def OneFlow_FuseLayerNormGradOp : OneFlow_BaseOp<"fuse_layer_norm_grad", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$dy,
OneFlow_Tensor:$x,
OneFlow_Tensor:$mean,
OneFlow_Tensor:$inv_variance,
Optional<OneFlow_Tensor>:$gamma,
Optional<OneFlow_Tensor>:$_add_to_output
);
let output = (outs
OneFlow_Tensor:$dx,
OneFlow_Tensor:$gamma_diff,
OneFlow_Tensor:$beta_diff
);
let attrs = (ins
DefaultValuedAttr<SI64Attr, "0">:$begin_norm_axis,
DefaultValuedAttr<SI64Attr, "0">:$begin_params_axis,
DefaultValuedAttr<F64Attr, "0.">:$epsilon
);
let trait_attrs = (ins
DenseI32ArrayAttr:$operand_segment_sizes,
DenseI32ArrayAttr:$result_segment_sizes
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}
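
Because gamma and _add_to_output are Optional inputs and the op carries AttrSizedOperandSegments, the generated MLIR operation records how many tensors actually fill each input slot through operand_segment_sizes. The small example below shows the two layouts that can occur, assuming the declared input order (dy, x, mean, inv_variance, gamma, _add_to_output); the arrays are hypothetical illustrations, not generated output.

#include <array>
#include <cstdio>

// Hypothetical operand_segment_sizes for fuse_layer_norm_grad: one entry per
// declared input slot, counting how many tensors fill it (0 for an omitted
// optional input).
int main() {
  const std::array<int, 6> gamma_only = {1, 1, 1, 1, 1, 0};     // gamma present, _add_to_output omitted
  const std::array<int, 6> gamma_and_add = {1, 1, 1, 1, 1, 1};  // both optional inputs present
  for (int s : gamma_only) { std::printf("%d ", s); }
  std::printf("\n");
  for (int s : gamma_and_add) { std::printf("%d ", s); }
  std::printf("\n");
  return 0;
}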

def OneFlow_LayerNormParamGradOp : OneFlow_BaseOp<"layer_norm_param_grad", [NoMemoryEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$dy,
20 changes: 20 additions & 0 deletions oneflow/user/kernels/layer_norm_cpu_kernel.cpp
@@ -57,6 +57,26 @@ class LayerNormGradCpuKernel final : public user_op::OpKernel {
REGISTER_LAYER_NORM_GRAD_CPU_KERNEL(float)
REGISTER_LAYER_NORM_GRAD_CPU_KERNEL(double)

template<typename T>
class FuseLayerNormGradCpuKernel final : public user_op::OpKernel {
public:
FuseLayerNormGradCpuKernel() = default;
~FuseLayerNormGradCpuKernel() = default;

private:
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
void Compute(user_op::KernelComputeContext* ctx) const override { TODO(); }  // CPU fused backward is not implemented; TODO() aborts if this kernel is ever dispatched.
};

#define REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(dtype) \
REGISTER_USER_KERNEL("fuse_layer_norm_grad") \
.SetCreateFn<FuseLayerNormGradCpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \
&& (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value));

REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(float)
REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(double)

template<typename T>
class LayerNormParamGradCpuKernel final : public user_op::OpKernel {
public:
125 changes: 125 additions & 0 deletions oneflow/user/ops/layer_norm_op.cpp
@@ -268,4 +268,129 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) {
return Maybe<void>::Ok();
}

/* static */ Maybe<void> FuseLayerNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0);
const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0);
user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0);
CHECK_EQ_OR_RETURN(dy.shape(), x.shape()) << "dy and x shapes should be equal.";
const int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
CHECK_GT_OR_RETURN(begin_norm_axis, 0) << "begin_norm_axis must be greater than 0.";
const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis);
CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape) << "mean shape must match bn_param_shape.";
CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape)
<< "inv_variance shape must match bn_param_shape.";
dx->set_shape(dy.shape());
dx->set_is_dynamic(dy.is_dynamic());
if (ctx->has_input("_add_to_output", 0)) {
const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape())
<< "add_to_output shape must match dx shape.";
}

auto has_tensor = [ctx](const std::string& bn) -> bool {
bool ret = false;
for (const auto& t : ctx->inputs()) {
if (bn == t.first) { return true; }
}
for (const auto& t : ctx->outputs()) {
if (bn == t.first) { return true; }
}
return ret;
};
const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
const bool has_beta_diff = has_tensor("beta_diff");
const bool has_gamma_diff = has_tensor("gamma_diff");
CHECK_GE_OR_RETURN(begin_params_axis, 1)
<< "begin_params_axis must be greater than or equal to 1.";
CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes())
<< "begin_params_axis must be less than the number of axes in dy shape.";
DimVector param_shape_dim_vec;
param_shape_dim_vec.insert(param_shape_dim_vec.end(),
dy.shape().dim_vec().cbegin() + begin_params_axis,
dy.shape().dim_vec().cend());
const Shape param_shape(param_shape_dim_vec);
if (has_beta_diff) {
user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0);
beta_diff->set_shape(param_shape);
}
if (has_gamma_diff) {
user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0);
gamma_diff->set_shape(param_shape);
}
return Maybe<void>::Ok();
}
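
A worked example of the shape inference above, with hypothetical sizes: for x = dy = [8, 512, 768] and begin_norm_axis == begin_params_axis == 2, mean and inv_variance must be [8, 512] (the leading, batch-like axes), gamma_diff and beta_diff come out as [768] (the trailing parameter axes), and dx keeps the full [8, 512, 768] shape.

#include <cstdio>
#include <vector>

// Mirrors the slicing done above: bn-param shape = x.shape[0:begin_norm_axis],
// parameter shape = dy.shape[begin_params_axis:]. Sizes are hypothetical.
int main() {
  const std::vector<long> x_shape = {8, 512, 768};
  const long begin_norm_axis = 2;
  const long begin_params_axis = 2;
  const std::vector<long> bn_param_shape(x_shape.begin(), x_shape.begin() + begin_norm_axis);  // mean / inv_variance
  const std::vector<long> param_shape(x_shape.begin() + begin_params_axis, x_shape.end());     // gamma_diff / beta_diff
  std::printf("mean/inv_variance: [%ld, %ld]\n", bn_param_shape[0], bn_param_shape[1]);
  std::printf("gamma_diff/beta_diff: [%ld]\n", param_shape[0]);
  return 0;
}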

/*static*/ Maybe<void> FuseLayerNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> FuseLayerNormGradOp::GetSbp(user_op::SbpContext* ctx) {
std::vector<user_op::OpArg> broadcast_args;
if (ctx->user_op_conf().has_input("gamma", 0)) { broadcast_args.emplace_back("gamma", 0); }
int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
CHECK_EQ_OR_RETURN(begin_norm_axis, begin_params_axis)
<< "begin_norm_axis and begin_params_axis must be equal, but got " << begin_norm_axis
<< " and " << begin_params_axis;
for (int i = 0; i < begin_norm_axis; ++i) {
ctx->NewBuilder()
.Split(ctx->inputs(), i)
.Split(user_op::OpArg("dx", 0), i)
.PartialSum(user_op::OpArg("gamma_diff", 0))
.PartialSum(user_op::OpArg("beta_diff", 0))
.Broadcast(broadcast_args)
.Build();
}
return Maybe<void>::Ok();
}
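
The loop above emits one SBP signature per batch-like axis: since begin_norm_axis equals begin_params_axis, every axis in front of the normalized dimensions can be split, dx follows that split, gamma_diff and beta_diff become partial sums that are reduced across ranks, and gamma (if present) is broadcast. The enumeration below lists those signatures for a hypothetical [batch, seq, hidden] input with begin_norm_axis == 2.

#include <cstdio>

// Hedged enumeration of the signatures registered above for a hypothetical
// 3-D input with begin_norm_axis == begin_params_axis == 2.
int main() {
  const int begin_norm_axis = 2;
  const char* axis_names[] = {"batch", "seq"};
  for (int i = 0; i < begin_norm_axis; ++i) {
    std::printf("split(%d) on %s: dy/x/mean/inv_variance/dx split, gamma broadcast, "
                "gamma_diff/beta_diff partial_sum\n",
                i, axis_names[i]);
  }
  return 0;
}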

/* static */ Maybe<void> FuseLayerNormGradOp::InferDataType(user_op::InferContext* ctx) {
const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type())
<< "InferDataType Failed. Expected " << DataType_Name(x.data_type()) << ", but got "
<< DataType_Name(dy.data_type());
const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0);
const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0);
DataType bn_param_data_type = InferBnParamDataType(x.data_type());
CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type)
<< "InferDataType Failed. Expected " << DataType_Name(bn_param_data_type) << ", but got "
<< DataType_Name(mean.data_type());
CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type)
<< "InferDataType Failed. Expected " << DataType_Name(bn_param_data_type) << ", but got "
<< DataType_Name(inv_variance.data_type());
user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0);
dx->set_data_type(dy.data_type());
if (ctx->has_input("_add_to_output", 0)) {
const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
CHECK_EQ_OR_RETURN(add_to_output.data_type(), dx->data_type())
<< "InferDataType Failed. Expected " << DataType_Name(dx->data_type()) << ", but got "
<< DataType_Name(add_to_output.data_type());
}

auto has_tensor = [ctx](const std::string& bn) -> bool {
bool ret = false;
for (auto& t : ctx->inputs()) {
if (bn == t.first) { return true; }
}
for (auto& t : ctx->outputs()) {
if (bn == t.first) { return true; }
}
return ret;
};
const bool has_beta_diff = has_tensor("beta_diff");
const bool has_gamma_diff = has_tensor("gamma_diff");
if (has_beta_diff) {
user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0);
beta_diff->set_data_type(dy.data_type());
}
if (has_gamma_diff) {
user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0);
gamma_diff->set_data_type(dy.data_type());
}
return Maybe<void>::Ok();
}

} // namespace oneflow