[WIP] Support forward-mode AD in eager mode #71075

Open
wants to merge 16 commits into base: develop
8 changes: 5 additions & 3 deletions paddle/fluid/eager/CMakeLists.txt
@@ -11,14 +11,15 @@ set(eager_deps
eager_nan_inf_utils
grad_node_info
grad_tensor_holder
custom_operator_node)
custom_operator_node
fwd_utils)

if(WITH_GPU OR WITH_ROCM)
set(eager_deps ${eager_deps} phi_gpu)
endif()

if(NOT (NOT WITH_PYTHON AND ON_INFER))
set(eager_deps ${eager_deps} accumulation_node prim_utils)
set(eager_deps ${eager_deps} accumulation_node prim_utils fwd_utils)
endif()

set(fluid_deps tracer layer proto_desc operator op_registry variable_helper)
@@ -32,6 +33,7 @@ endif()

add_subdirectory(api)
add_subdirectory(custom_operator)
add_subdirectory(fwd)
if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
add_subdirectory(accumulation)
add_subdirectory(pylayer)
@@ -42,7 +44,7 @@ if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
add_dependencies(grad_tensor_holder eager_codegen)
cc_library(
backward
SRCS backward.cc
SRCS backward.cc fwd/forward.cc fwd/forward_grad.cc fwd/forward_grad.h
DEPS grad_tensor_holder utils autograd_meta grad_node_info phi common)
endif()

3 changes: 3 additions & 0 deletions paddle/fluid/eager/api/manual/CMakeLists.txt
@@ -14,4 +14,7 @@ if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
set(eager_manual_nodes
${eager_manual_nodes}
PARENT_SCOPE)
set(eager_manual_jvps
${eager_manual_jvps}
PARENT_SCOPE)
endif()
4 changes: 4 additions & 0 deletions paddle/fluid/eager/api/manual/eager_manual/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(forward_grad)
add_subdirectory(forwards)
add_subdirectory(nodes)
set(eager_manual_functions
@@ -6,3 +7,6 @@ set(eager_manual_functions
set(eager_manual_nodes
${eager_manual_nodes}
PARENT_SCOPE)
set(eager_manual_jvps
${eager_manual_jvps}
PARENT_SCOPE)
3 changes: 3 additions & 0 deletions paddle/fluid/eager/api/manual/eager_manual/forward_grad/CMakeLists.txt
@@ -0,0 +1,3 @@
set(eager_manual_jvps
${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/manual/eager_manual/forward_grad/manual_jvp.cc
PARENT_SCOPE)
43 changes: 43 additions & 0 deletions paddle/fluid/eager/api/manual/eager_manual/forward_grad/manual_jvp.cc
@@ -0,0 +1,43 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Manual JVP rules for dygraph forward_autograd.
// There are three kinds of JVP rules:
// 1. elementwise functions, e.g. tanh: we can reuse their vjp rules and just
//    replace the out_grad with the input tangent
// 2. linear functions with a single input, e.g. scale: we can reuse their
//    forward functions and just replace the input with the input tangent
// 3. other cases, e.g. concat/stack/batch_norm: we need to implement the jvp
//    rules manually
//
#include "paddle/fluid/eager/api/manual/eager_manual/forward_grad/manual_jvp.h"
#include <vector>
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/optional.h"

Tensor concat_jvp(const std::vector<Tensor>& x_ts, Scalar axis) {
  std::vector<Tensor> fw_grads;
  for (const Tensor& t : x_ts) {
    if (egr::EagerUtils::nullable_autograd_meta(t)) {
      fw_grads.push_back(t._fw_grad(/*level*/ 0));
    } else {
      fw_grads.push_back(
          paddle::experimental::zeros(t.shape(), t.dtype(), t.place()));
    }
  }
  return concat_ad_func(fw_grads, axis);
}
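
To contrast with the case-3 rule above, a case-1 (elementwise) rule such as tanh would, under the scheme described in the file comment, reuse the VJP formula with the input tangent in place of the output gradient. The following is a minimal sketch matching the tanh_jvp declaration that is still commented out in manual_jvp.h; the generated helpers it calls (multiply_ad_func, scale_ad_func) and their exact signatures are assumptions for illustration, not part of this PR.

// Sketch only: case-1 elementwise JVP for tanh, reusing the VJP formula
// d tanh(x)/dx = 1 - tanh(x)^2 with the input tangent in place of out_grad.
void tanh_jvp(const paddle::Tensor& x_t,    // tangent of the input x
              const paddle::Tensor& out_p,  // primal output, i.e. tanh(x)
              paddle::Tensor* out_t) {      // tangent of the output
  // 1 - out_p * out_p, assuming scale_ad_func computes scale * x + bias.
  Tensor one_minus_out2 = scale_ad_func(multiply_ad_func(out_p, out_p),
                                        /*scale=*/-1.0,
                                        /*bias=*/1.0f,
                                        /*bias_after_scale=*/true);
  // Replace out_grad in the VJP with the input tangent x_t.
  *out_t = multiply_ad_func(x_t, one_minus_out2);
}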
61 changes: 61 additions & 0 deletions paddle/fluid/eager/api/manual/eager_manual/forward_grad/manual_jvp.h
@@ -0,0 +1,61 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Manual JVP rules for dygraph forward_autograd.
// There are three kinds of JVP rules:
// 1. elementwise functions, e.g. tanh: we can reuse their vjp rules and just
//    replace the out_grad with the input tangent
// 2. linear functions with a single input, e.g. scale: we can reuse their
//    forward functions and just replace the input with the input tangent
// 3. other cases, e.g. concat/stack/batch_norm: we need to implement the jvp
//    rules manually
//
#pragma once

#include <vector>
#include "paddle/common/flags.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/fluid/imperative/amp_utils.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/platform/profiler/event_tracing.h"

using Scalar = paddle::Scalar;
using Tensor = paddle::Tensor;

// void add_jvp(const paddle::Tensor& x_p,
// const paddle::Tensor& x_t,
// const paddle::Tensor& y_p,
// const paddle::Tensor& y_t,
// paddle::Tensor* out_t);

// void scale_jvp(const paddle::Tensor& x_p,
// const paddle::Tensor& x_t,
// paddle::experimental::Scalar scale,
// paddle::experimental::Scalar bias,
// bool bias_after_scale,
// paddle::Tensor* out_t);

// void tanh_jvp(const paddle::Tensor& x_t,
// const paddle::Tensor& out_p,
// paddle::Tensor* out_t);

paddle::Tensor concat_jvp(const std::vector<Tensor>& x_ts, Scalar axis);
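
The commented-out declarations above outline the planned interfaces for the other two categories. As a purely illustrative example of case 2, a linear op like scale could reuse its own forward function with the tangent substituted for the input; the bias term drops out because it does not affect the derivative. The sketch below follows the commented-out scale_jvp declaration, and the helper name scale_ad_func and its signature are assumptions rather than something this PR defines.

// Sketch only: case-2 linear JVP matching the scale_jvp declaration above.
// For out = scale * x + bias, the tangent is out_t = scale * x_t; the bias
// contributes nothing, so the forward function is reused with bias set to 0.
void scale_jvp(const paddle::Tensor& x_p,  // primal input, unused; kept only
                                           // to mirror the declared interface
               const paddle::Tensor& x_t,  // tangent of the input x
               paddle::experimental::Scalar scale,
               paddle::experimental::Scalar bias,
               bool bias_after_scale,
               paddle::Tensor* out_t) {    // tangent of the output
  (void)x_p;
  (void)bias;
  *out_t = scale_ad_func(x_t, scale, /*bias=*/0.0f, bias_after_scale);
}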
13 changes: 13 additions & 0 deletions paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py
@@ -510,6 +510,7 @@ def __init__(self, forward_api_contents, namespace):
) # {name: func_name, args: [input_name, ...]}
self.intermediate_outputs = [] # [name, ...]
self.forward_inplace_map = {} # {name : name, ...}
self.jvp_rule = None # jvp_rule: auto_linear/auto_elementwise/xxx_jvp

def ParseForwardInplaceInfo(self):
forward_api_contents = self.forward_api_contents
@@ -556,6 +557,18 @@ def ParseIntermediate(self):
name = RemoveSpecialSymbolsInName(name)
self.intermediate_outputs.append(name)

def ParseJvpRule(self):
    self.jvp_rule = self.forward_api_contents.get('jvp_rule', None)
    if (
        self.jvp_rule is not None
        and self.jvp_rule != "auto_elementwise"
        and self.jvp_rule != "auto_linear"
        and not self.jvp_rule.endswith("_jvp")
    ):
        raise ValueError(
            f"The JVP rule of the operator '{self.forward_api_name}' should be 'auto_elementwise', 'auto_linear', or 'xxx_jvp', but got {self.jvp_rule}."
        )
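
For context, ParseJvpRule reads the optional jvp_rule field from the parsed operator YAML and accepts only 'auto_elementwise', 'auto_linear', or a name ending in '_jvp'. The snippet below is a hypothetical illustration of the accepted values; the dictionaries are made up for this example and are not entries from the real ops definitions.

# Hypothetical forward_api_contents dictionaries, showing only 'jvp_rule'.
tanh_contents = {"op": "tanh", "jvp_rule": "auto_elementwise"}  # case 1: reuse the vjp rule
scale_contents = {"op": "scale", "jvp_rule": "auto_linear"}     # case 2: reuse the forward function
concat_contents = {"op": "concat", "jvp_rule": "concat_jvp"}    # case 3: hand-written rule
bad_contents = {"op": "foo", "jvp_rule": "linear"}              # rejected: ParseJvpRule raises ValueError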

def CollectOriginalForwardInfo(self):
forward_api_contents = self.forward_api_contents
assert (