Commit
Merge branch 'master' into dev_refactor_xccl_primitive
Flowingsun007 authored Jan 26, 2025
2 parents 65e3046 + cb699cd commit 0a883a7
Showing 8 changed files with 245 additions and 16 deletions.
49 changes: 33 additions & 16 deletions oneflow/core/autograd/gradient_funcs/layer_norm.cpp
@@ -18,6 +18,9 @@ limitations under the License.
#include "oneflow/core/functional/functional.h"

namespace oneflow {

DEFINE_ENV_BOOL(ONEFLOW_USE_FUSE_LAYER_NORM_GRAD, false);

namespace one {

struct LayerNormCaptureState : public AutoGradCaptureState {
@@ -107,22 +110,36 @@ Maybe<void> LayerNorm::Apply(const LayerNormCaptureState* ctx, const TensorTuple
std::shared_ptr<Tensor> mean = saved_tensors.at(ctx->mean_index);
std::shared_ptr<Tensor> inv_variance = saved_tensors.at(ctx->inv_variance_index);

if (ctx->has_affine) {
// Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance,
// Int64 begin_params_axis)
const auto& results =
JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis));
in_grads->at(1) = results->at(0); // For gamma.
in_grads->at(2) = results->at(1); // For beta.
}
if (ctx->x_requires_grad) {
if (ctx->scale) {
std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
in_grads->at(0) = JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma,
begin_norm_axis, ctx->epsilon));
} else {
in_grads->at(0) =
JUST(functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, ctx->epsilon));
if (EnvBool<ONEFLOW_USE_FUSE_LAYER_NORM_GRAD>()) {
// Fused backward path, currently used only on NPU.
CHECK(ctx->has_affine) << "LayerNorm::Apply requires has_affine when ONEFLOW_USE_FUSE_LAYER_NORM_GRAD is enabled (NPU GPT-2 test)";
if (ctx->x_requires_grad) {
if (ctx->scale) {
std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
*in_grads = *JUST(functional::FuseLayerNormGrad(
dy, x, mean, inv_variance, gamma, begin_norm_axis, begin_params_axis, ctx->epsilon));
} else {
UNIMPLEMENTED();
}
}
} else {
if (ctx->has_affine) {
// Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance,
// Int64 begin_params_axis)
const auto& results =
JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis));
in_grads->at(1) = results->at(0); // For gamma.
in_grads->at(2) = results->at(1); // For beta.
}
if (ctx->x_requires_grad) {
if (ctx->scale) {
std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
in_grads->at(0) = JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma,
begin_norm_axis, ctx->epsilon));
} else {
in_grads->at(0) = JUST(
functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, ctx->epsilon));
}
}
}
return Maybe<void>::Ok();
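
A minimal usage sketch (illustrative, not part of this commit) of the gated path above: assuming the ONEFLOW_USE_FUSE_LAYER_NORM_GRAD flag is read at runtime by the gradient function and the fused kernel is actually implemented on the target device (e.g. NPU; the CPU kernel added below is still a stub), setting the flag before running an affine LayerNorm backward should route the gradient through FuseLayerNormGrad. Tensor shapes and module setup are assumptions for illustration.

# Illustrative sketch, not from the commit.
import os
os.environ["ONEFLOW_USE_FUSE_LAYER_NORM_GRAD"] = "1"  # enable the fused gradient path

import oneflow as flow

x = flow.randn(8, 128, 768, requires_grad=True)
layer_norm = flow.nn.LayerNorm(768)  # elementwise_affine=True by default, so gamma/beta exist

y = layer_norm(x)
y.sum().backward()

print(x.grad.shape)                  # dx
print(layer_norm.weight.grad.shape)  # gamma_diff
print(layer_norm.bias.grad.shape)    # beta_diff
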
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -1558,6 +1558,10 @@
signature: "Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Double epsilon) => LayerNormAffineGrad"
bind_python: False

- name: "fuse_layer_norm_grad"
signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => FuseLayerNormGrad"
bind_python: False

- name: "layer_norm_param_grad"
signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_params_axis) => LayerNormParamGrad"
bind_python: False
31 changes: 31 additions & 0 deletions oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -983,6 +983,36 @@ class LayerNormAffineGradFunctor {
std::shared_ptr<OpExpr> op_;
};

class FuseLayerNormGradFunctor {
public:
FuseLayerNormGradFunctor() {
op_ = CHECK_JUST(one::OpBuilder("fuse_layer_norm_grad")
.Input("dy")
.Input("x")
.Input("mean")
.Input("inv_variance")
.Input("gamma")
.Output("dx")
.Output("gamma_diff")
.Output("beta_diff")
.Build());
}
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,
const std::shared_ptr<one::Tensor>& x,
const std::shared_ptr<one::Tensor>& mean,
const std::shared_ptr<one::Tensor>& inv_variance,
const std::shared_ptr<one::Tensor>& gamma,
const int64_t& begin_norm_axis, const int64_t& begin_params_axis,
const double& epsilon) const {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon");
attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon);
return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance, gamma}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class LayerNormParamGradFunctor {
public:
LayerNormParamGradFunctor() {
@@ -1707,6 +1737,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::LayerNormGradFunctor>("LayerNormGrad");
m.add_functor<impl::LayerNormAffineGradFunctor>("LayerNormAffineGrad");
m.add_functor<impl::LayerNormParamGradFunctor>("LayerNormParamGrad");
m.add_functor<impl::FuseLayerNormGradFunctor>("FuseLayerNormGrad");
m.add_functor<impl::GroupNormGradFunctor>("GroupNormGrad");
m.add_functor<impl::GroupNormParamGradFunctor>("GroupNormParamGrad");
m.add_functor<impl::BroadcastMatmulGradBFunctor>("BroadcastMatmulGradB");
2 changes: 2 additions & 0 deletions oneflow/core/job_rewriter/auto_mixed_precision.cpp
@@ -359,6 +359,8 @@ REGISTER_NO_CAST_REGISTRY("layer_norm_grad", "mean", 0)
REGISTER_NO_CAST_REGISTRY("layer_norm_grad", "inv_variance", 0)
REGISTER_NO_CAST_REGISTRY("layer_norm_param_grad", "mean", 0)
REGISTER_NO_CAST_REGISTRY("layer_norm_param_grad", "inv_variance", 0)
REGISTER_NO_CAST_REGISTRY("fuse_layer_norm_grad", "mean", 0)
REGISTER_NO_CAST_REGISTRY("fuse_layer_norm_grad", "inv_variance", 0)

} // namespace

1 change: 1 addition & 0 deletions oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp
@@ -95,6 +95,7 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
"layer_norm",
"layer_norm_param_grad",
"layer_norm_grad",
"fuse_layer_norm_grad",
"skip_layer_norm",
"rms_norm",
"rms_norm_grad",
29 changes: 29 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -7071,6 +7071,35 @@ def OneFlow_LayerNormGradOp : OneFlow_BaseOp<"layer_norm_grad", [NoMemoryEffect,
let has_data_type_infer_fn = 1;
}

def OneFlow_FuseLayerNormGradOp : OneFlow_BaseOp<"fuse_layer_norm_grad", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$dy,
OneFlow_Tensor:$x,
OneFlow_Tensor:$mean,
OneFlow_Tensor:$inv_variance,
Optional<OneFlow_Tensor>:$gamma,
Optional<OneFlow_Tensor>:$_add_to_output
);
let output = (outs
OneFlow_Tensor:$dx,
OneFlow_Tensor:$gamma_diff,
OneFlow_Tensor:$beta_diff
);
let attrs = (ins
DefaultValuedAttr<SI64Attr, "0">:$begin_norm_axis,
DefaultValuedAttr<SI64Attr, "0">:$begin_params_axis,
DefaultValuedAttr<F64Attr, "0.">:$epsilon
);
let trait_attrs = (ins
DenseI32ArrayAttr:$operand_segment_sizes,
DenseI32ArrayAttr:$result_segment_sizes
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_LayerNormParamGradOp : OneFlow_BaseOp<"layer_norm_param_grad", [NoMemoryEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$dy,
20 changes: 20 additions & 0 deletions oneflow/user/kernels/layer_norm_cpu_kernel.cpp
@@ -57,6 +57,26 @@ class LayerNormGradCpuKernel final : public user_op::OpKernel {
REGISTER_LAYER_NORM_GRAD_CPU_KERNEL(float)
REGISTER_LAYER_NORM_GRAD_CPU_KERNEL(double)

template<typename T>
class FuseLayerNormGradCpuKernel final : public user_op::OpKernel {
public:
FuseLayerNormGradCpuKernel() = default;
~FuseLayerNormGradCpuKernel() = default;

private:
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
void Compute(user_op::KernelComputeContext* ctx) const override { TODO(); };
};

#define REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(dtype) \
REGISTER_USER_KERNEL("fuse_layer_norm_grad") \
.SetCreateFn<FuseLayerNormGradCpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \
&& (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value));

REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(float)
REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(double)

template<typename T>
class LayerNormParamGradCpuKernel final : public user_op::OpKernel {
public:
125 changes: 125 additions & 0 deletions oneflow/user/ops/layer_norm_op.cpp
@@ -268,4 +268,129 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) {
return Maybe<void>::Ok();
}

/* static */ Maybe<void> FuseLayerNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0);
const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0);
user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0);
CHECK_EQ_OR_RETURN(dy.shape(), x.shape()) << "dy and x shapes should be equal.";
const int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
CHECK_GT_OR_RETURN(begin_norm_axis, 0) << "begin_norm_axis must be greater than 0.";
const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis);
CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape) << "mean shape must match bn_param_shape.";
CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape)
<< "inv_variance shape must match bn_param_shape.";
dx->set_shape(dy.shape());
dx->set_is_dynamic(dy.is_dynamic());
if (ctx->has_input("_add_to_output", 0)) {
const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape())
<< "add_to_output shape must match dx shape.";
}

auto has_tensor = [ctx](const std::string& bn) -> bool {
bool ret = false;
for (const auto& t : ctx->inputs()) {
if (bn == t.first) { return true; }
}
for (const auto& t : ctx->outputs()) {
if (bn == t.first) { return true; }
}
return ret;
};
const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
const bool has_beta_diff = has_tensor("beta_diff");
const bool has_gamma_diff = has_tensor("gamma_diff");
CHECK_GE_OR_RETURN(begin_params_axis, 1)
<< "begin_params_axis must be greater than or equal to 1.";
CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes())
<< "begin_params_axis must be less than the number of axes in dy shape.";
DimVector param_shape_dim_vec;
param_shape_dim_vec.insert(param_shape_dim_vec.end(),
dy.shape().dim_vec().cbegin() + begin_params_axis,
dy.shape().dim_vec().cend());
const Shape param_shape(param_shape_dim_vec);
if (has_beta_diff) {
user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0);
beta_diff->set_shape(param_shape);
}
if (has_gamma_diff) {
user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0);
gamma_diff->set_shape(param_shape);
}
return Maybe<void>::Ok();
}

/*static*/ Maybe<void> FuseLayerNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> FuseLayerNormGradOp::GetSbp(user_op::SbpContext* ctx) {
std::vector<user_op::OpArg> broadcast_args;
if (ctx->user_op_conf().has_input("gamma", 0)) { broadcast_args.emplace_back("gamma", 0); }
int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
CHECK_EQ(begin_norm_axis, begin_params_axis)
<< "begin_norm_axis and begin_params_axis must be equal, but got " << begin_norm_axis
<< " and " << begin_params_axis;
for (int i = 0; i < begin_norm_axis; ++i) {
ctx->NewBuilder()
.Split(ctx->inputs(), i)
.Split(user_op::OpArg("dx", 0), i)
.PartialSum(user_op::OpArg("gamma_diff", 0))
.PartialSum(user_op::OpArg("beta_diff", 0))
.Broadcast(broadcast_args)
.Build();
}
return Maybe<void>::Ok();
}

/* static */ Maybe<void> FuseLayerNormGradOp::InferDataType(user_op::InferContext* ctx) {
const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type())
<< "InferDataType Failed. Expected " << DataType_Name(x.data_type()) << ", but got "
<< DataType_Name(dy.data_type());
const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0);
const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0);
DataType bn_param_data_type = InferBnParamDataType(x.data_type());
CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type)
<< "InferDataType Failed. Expected " << DataType_Name(bn_param_data_type) << ", but got "
<< DataType_Name(mean.data_type());
CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type)
<< "InferDataType Failed. Expected " << DataType_Name(bn_param_data_type) << ", but got "
<< DataType_Name(inv_variance.data_type());
user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0);
dx->set_data_type(dy.data_type());
if (ctx->has_input("_add_to_output", 0)) {
const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
CHECK_EQ_OR_RETURN(add_to_output.data_type(), dx->data_type())
<< "InferDataType Failed. Expected " << DataType_Name(dx->data_type()) << ", but got "
<< DataType_Name(add_to_output.data_type());
}

auto has_tensor = [ctx](const std::string& bn) -> bool {
bool ret = false;
for (auto& t : ctx->inputs()) {
if (bn == t.first) { return true; }
}
for (auto& t : ctx->outputs()) {
if (bn == t.first) { return true; }
}
return ret;
};
const bool has_beta_diff = has_tensor("beta_diff");
const bool has_gamma_diff = has_tensor("gamma_diff");
if (has_beta_diff) {
user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0);
beta_diff->set_data_type(dy.data_type());
}
if (has_gamma_diff) {
user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0);
gamma_diff->set_data_type(dy.data_type());
}
return Maybe<void>::Ok();
}

} // namespace oneflow
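
For reference, a small plain-Python sketch (illustrative, with assumed example shapes) of the shape relationships that FuseLayerNormGradOp::InferLogicalTensorDesc above checks and produces: mean/inv_variance follow the leading axes up to begin_norm_axis, while gamma_diff/beta_diff follow the trailing axes from begin_params_axis, and GetSbp requires the two axes to be equal.

# Illustrative shape sketch for fuse_layer_norm_grad (example values assumed).
x_shape = (8, 128, 768)        # dy and dx share this shape
begin_norm_axis = 2
begin_params_axis = 2          # GetSbp above checks begin_norm_axis == begin_params_axis

bn_param_shape = x_shape[:begin_norm_axis]   # mean / inv_variance shape
param_shape = x_shape[begin_params_axis:]    # gamma_diff / beta_diff shape

assert bn_param_shape == (8, 128)
assert param_shape == (768,)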
