【Auto-Parallel | Comm】fix communication hang issue on GPU-H(VPP) #71104

Open
wants to merge 17 commits into base: develop
Changes from all commits
17 commits
23694c1
Support a flexible model-layer allocation strategy for unbalanced VPP scheduling
zty-king Dec 14, 2024
a067477
Support a flexible model-layer allocation strategy for unbalanced VPP scheduling
zty-king Dec 15, 2024
e3e1e71
Support a flexible model-layer allocation strategy for unbalanced VPP scheduling
zty-king Dec 15, 2024
708b467
Support a flexible model-layer allocation strategy for unbalanced VPP scheduling
zty-king Dec 15, 2024
abd297a
Support a flexible model-layer allocation strategy for unbalanced VPP scheduling
zty-king Dec 16, 2024
b4cc018
Support a flexible model-layer allocation strategy for unbalanced VPP scheduling
zty-king Dec 16, 2024
4f7e586
Support a flexible model-layer allocation strategy for unbalanced VPP scheduling
zty-king Dec 18, 2024
b0d9ac9
Support a flexible model-layer allocation strategy for unbalanced VPP scheduling
zty-king Jan 8, 2025
30f543a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zty-king Jan 8, 2025
5c6d235
Support a flexible model-layer allocation strategy for unbalanced VPP scheduling
zty-king Jan 9, 2025
94a0d32
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zty-king Jan 9, 2025
9fc422d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zty-king Feb 6, 2025
7f42ad0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zty-king Feb 12, 2025
38b8506
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zty-king Feb 12, 2025
94c453e
fix communication hang issue on GPU-H(vpp)
zty-king Feb 12, 2025
602dde1
fix communication hang issue on GPU-H(VPP)
zty-king Feb 13, 2025
665ae80
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zty-king Feb 13, 2025
5 changes: 3 additions & 2 deletions python/paddle/distributed/auto_parallel/static/pir_pass.py
@@ -1112,8 +1112,9 @@ def complete_chunk_id(dist_program, startup_program, pipeline_strategy):
seg_chunk_ids = [i // pp_degree for i in range(num_chunks)]
seg_parts = [0]
last_struct_name = None
# stage_ids[i] represents the stage number assigned to the i-th layer.
stage_ids = []
stage_ids = (
[]
) # stage_ids[i] represents the stage number assigned to the i-th layer.

for idx, op in enumerate(ops):
if len(seg_parts) == len(seg_struct_names):
157 changes: 128 additions & 29 deletions python/paddle/distributed/passes/pass_utils.py
@@ -1099,7 +1099,6 @@ def add_persistable_var(op_idx, program_type):
type_to_ops[program_type][op_idx].result(idx),
var_name,
)
# type_to_ops[program_type][op_idx].result(idx).persistable = True

program_block = type_to_program[type].global_block()
new_result_var = program_block.add_kwarg(
@@ -1111,36 +1110,106 @@ def add_persistable_var(op_idx, program_type):
new_result_var
)

_add_dependency_if_necessary(
program_type, type, op_idx, idx, var_name
)

for type in following_program_types:
type_to_ops[type][op_idx].erase()

def _add_dependency(recorder_op, waiter_op, name):
'''
Add an extra event dependency between the two operators.
This is mainly used for cross-program dependencies in pipeline parallelism,
especially for 'send_v2' / 'recv_v2' ops.
'''
if not recorder_op.has_attr("force_record_event"):
recorder_op.set_bool_attr("force_record_event", True)
recorder_op.set_str_attr("event_to_record", name)
waiter_op.set_str_array_attr("events_to_wait", [name])

def _add_dependency_if_necessary(
cur_job_type, next_job_type, op_idx, rst_idx, var_name
):
if not (
("backward" in cur_job_type and "send_backward" in next_job_type)
or ("recv_forward" in cur_job_type and "forward" in next_job_type)
):
return

first_used_idx = None
first_used_op = None
for used_op in (
type_to_ops[next_job_type][op_idx].result(rst_idx).all_used_ops()
):
used_idx = type_to_ops[next_job_type].index(used_op)
if first_used_idx is None or used_idx < first_used_idx:
first_used_idx = used_idx
first_used_op = used_op

if first_used_op is not None:
_add_dependency(
type_to_ops[cur_job_type][op_idx], first_used_op, var_name
)
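# A minimal sketch of how the two helpers above fit together; the op names are
# hypothetical and only illustrate the attribute flow. The op that produces
# `var_name` in the upstream program (e.g. the recv in "recv_forward0") records
# a named event, and the first op consuming `var_name` in the downstream
# program (e.g. "forward0") waits on it:
#
#     _add_dependency(producer_op, first_consumer_op, var_name)
#
# i.e. force_record_event=True and event_to_record=var_name are set on the
# producer, and events_to_wait=[var_name] on the consumer, so the downstream
# job cannot start using the variable before the communication completes.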

type_to_program = OrderedDict()
type_to_ops = OrderedDict()

# Step1: create programs and ops for each type
for type in oprole_names:
if type == "optimizer":
type_to_program["optimizer"] = program.clone()
type_to_ops["optimizer"] = (
type_to_program["optimizer"].global_block().ops
if not split_bw:
chunk_ids = list(range(num_model_chunks))
# Forward process: the recv and forward of each chunk are put together
for chunk_id in chunk_ids:
type_to_program[f"recv_forward{chunk_id}"] = program.clone()
type_to_ops[f"recv_forward{chunk_id}"] = (
type_to_program[f"recv_forward{chunk_id}"].global_block().ops
)
else:
chunk_ids = list(range(num_model_chunks))
if "backward" in type:
chunk_ids.reverse()
for chunk_id in chunk_ids:
type_to_program[type + str(chunk_id)] = program.clone()
type_to_ops[type + str(chunk_id)] = (
type_to_program[type + str(chunk_id)].global_block().ops

type_to_program[f"forward{chunk_id}"] = program.clone()
type_to_ops[f"forward{chunk_id}"] = (
type_to_program[f"forward{chunk_id}"].global_block().ops
)

# Reverse process: the backward and send of each chunk are put together
for chunk_id in reversed(chunk_ids):
type_to_program[f"backward{chunk_id}"] = program.clone()
type_to_ops[f"backward{chunk_id}"] = (
type_to_program[f"backward{chunk_id}"].global_block().ops
)

type_to_program[f"send_backward{chunk_id}"] = program.clone()
type_to_ops[f"send_backward{chunk_id}"] = (
type_to_program[f"send_backward{chunk_id}"].global_block().ops
)

type_to_program["optimizer"] = program.clone()
type_to_ops["optimizer"] = (
type_to_program["optimizer"].global_block().ops
)
else:
for type in oprole_names:
if type == "optimizer":
type_to_program["optimizer"] = program.clone()
type_to_ops["optimizer"] = (
type_to_program["optimizer"].global_block().ops
)
else:
chunk_ids = list(range(num_model_chunks))
if "backward" in type:
chunk_ids.reverse()
for chunk_id in chunk_ids:
type_to_program[type + str(chunk_id)] = program.clone()
type_to_ops[type + str(chunk_id)] = (
type_to_program[type + str(chunk_id)].global_block().ops
)

# Step2: delete the ops that do not belong to the type
# 1. delete ops
# 2. add persistable vars used across multiple programs
all_ops = program.global_block().ops
chunk_ids = list(range(num_model_chunks))

bwd_pattern_ops_type = []

for idx in range(len(all_ops) - 1, -1, -1):
op = all_ops[idx]
op_role = op.op_role
@@ -1161,26 +1230,50 @@ def add_persistable_var(op_idx, program_type):
bwd_pattern_ops_type = _pir_get_backward_op_type(all_ops, idx)
job_type = bwd_pattern_ops_type.pop()
elif op_role == int(OpRole.Backward) and (not split_bw):
job_type = "backward"
if op.name() == "pd_op.send_v2":
job_type = "send_backward"
else:
job_type = "backward"
elif op_role == int(OpRole.Forward):
job_type = "forward"
if op.name() == "pd_op.recv_v2" and (not split_bw):
job_type = "recv_forward"
else:
job_type = "forward"
else:
raise ValueError(
f"The op[{op.name()}]'s op role: {op_role} isn't one of Forward, Backward or Optimizer."
f"The op[{op.name()}]'s op role: {op_role} isn't one of recv_forward, forward, backward, send_backward or Optimizer."
)

# Step2.3: delete ops that do not belong to the type
for type in oprole_names:
if type == job_type:
break
for chunk_id in chunk_ids:
type_to_ops[type + str(chunk_id)][idx].erase()
if not split_bw:
current_type = (
job_type
if job_type == "optimizer"
else job_type + str(op_chunk_id)
)

# Get the position of the current type in type_to_program
all_types = list(type_to_ops.keys())
current_idx = all_types.index(current_type)

chunk_order = range(0, op_chunk_id)
if "backward" in job_type:
chunk_order = range(num_model_chunks - 1, op_chunk_id, -1)
for chunk_id in chunk_order:
type_to_ops[job_type + str(chunk_id)][idx].erase()
# Delete all ops before the current type
for type_name in all_types[:current_idx]:
type_to_ops[type_name][idx].erase()
else:
for type in oprole_names:
if type == job_type:
break
if type != "optimizer":
for chunk_id in chunk_ids:
type_to_ops[type + str(chunk_id)][idx].erase()
else:
type_to_ops[type][idx].erase()

chunk_order = range(0, op_chunk_id)
if "backward" in job_type:
chunk_order = range(num_model_chunks - 1, op_chunk_id, -1)
for chunk_id in chunk_order:
type_to_ops[job_type + str(chunk_id)][idx].erase()

# Step2.4: add persistable vars used across multiple programs
if job_type != "optimizer":
@@ -1194,7 +1287,13 @@ def _pir_program_for_vpp(
):
_pir_overlap_send_recv(program)

oprole_names = ["forward", "backward", "optimizer"]
oprole_names = [
"recv_forward",
"forward",
"backward",
"send_backward",
"optimizer",
]
if split_bw:
oprole_names = ["forward", "backward_b", "backward_w", "optimizer"]
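
When split_bw is False, the job types above expand into one sub-program per chunk. A minimal sketch of the resulting program order built in Step1, assuming a hypothetical num_model_chunks = 2 chosen only for illustration:

num_model_chunks = 2  # hypothetical value for illustration

job_types = []
# Forward direction: recv_forward and forward of each chunk are grouped together.
for chunk_id in range(num_model_chunks):
    job_types += [f"recv_forward{chunk_id}", f"forward{chunk_id}"]
# Backward direction: backward and send_backward are grouped in reverse chunk order.
for chunk_id in reversed(range(num_model_chunks)):
    job_types += [f"backward{chunk_id}", f"send_backward{chunk_id}"]
job_types.append("optimizer")

print(job_types)
# ['recv_forward0', 'forward0', 'recv_forward1', 'forward1',
#  'backward1', 'send_backward1', 'backward0', 'send_backward0', 'optimizer']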
