Skip to content

Commit c776929

Browse files
committed
fix: reset node_retry_counts on fresh node visits in feedback-loop graphs
Fixes #6605. node_retry_counts was never reset when a node was re-visited via an ON_FAILURE feedback edge. From visit 2 onward the node's accumulated failure count was already >= max_retries before making a single attempt, granting zero retries and immediately routing to failure again. The core fix adds one line to the fresh-visit bookkeeping block (around line 823 pre-change, line 825 post-change): node_retry_counts[current_node_id] = 0 when _is_retry is False. This mirrors the existing node_visit_counts handling and gives each fresh visit a clean retry budget. _is_retry=True skips this block during retry iterations, so intra-visit retry accumulation is unaffected. Because node_retry_counts is now reset per visit, a companion never-reset dict, node_retry_totals, is introduced to accumulate cumulative retry counts, and all reporting/quality sites (total_retries, nodes_with_failures, retry_details, execution_quality, is_clean) now read from node_retry_totals so reported metrics remain cumulative across visits. Tests added (3 new, all passing): - test_retry_budget_reset_on_revisit_via_on_failure: fails 3x then succeeds on visit 2 attempt 1 — verifies reset works - test_retry_budget_independent_per_visit: two full failure cycles then success — verifies each visit gets independent budget - test_linear_graph_retry_unaffected: confirms existing retry behavior unchanged for graphs without feedback loops
1 parent 9c0ba77 commit c776929

2 files changed

Lines changed: 351 additions & 18 deletions

File tree

core/framework/graph/executor.py

Lines changed: 23 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -550,7 +550,8 @@ async def execute(
550550
path: list[str] = []
551551
total_tokens = 0
552552
total_latency = 0
553-
node_retry_counts: dict[str, int] = {} # Track retries per node
553+
node_retry_counts: dict[str, int] = {} # Per-visit retry budget (reset on fresh visits)
554+
node_retry_totals: dict[str, int] = {} # Cumulative retry counts (never reset)
554555
node_visit_counts: dict[str, int] = {} # Track visits for feedback loops
555556
_is_retry = False # True when looping back for a retry (not a new visit)
556557

@@ -821,6 +822,7 @@ async def execute(
821822
if not _is_retry:
822823
cnt = node_visit_counts.get(current_node_id, 0) + 1
823824
node_visit_counts[current_node_id] = cnt
825+
node_retry_counts[current_node_id] = 0 # fresh visit = fresh retry budget
824826
_is_retry = False
825827
max_visits = getattr(node_spec, "max_node_visits", 0)
826828
if max_visits > 0 and node_visit_counts[current_node_id] > max_visits:
@@ -946,7 +948,7 @@ async def execute(
946948
current_node=node_spec.id,
947949
execution_path=list(path),
948950
memory=memory,
949-
is_clean=(sum(node_retry_counts.values()) == 0),
951+
is_clean=(sum(node_retry_totals.values()) == 0),
950952
)
951953

952954
if checkpoint_config.async_checkpoint:
@@ -1080,6 +1082,9 @@ async def execute(
10801082
node_retry_counts[current_node_id] = (
10811083
node_retry_counts.get(current_node_id, 0) + 1
10821084
)
1085+
node_retry_totals[current_node_id] = (
1086+
node_retry_totals.get(current_node_id, 0) + 1
1087+
)
10831088

10841089
# [CORRECTED] Use node_spec.max_retries instead of hardcoded 3
10851090
max_retries = getattr(node_spec, "max_retries", 3)
@@ -1166,8 +1171,8 @@ async def execute(
11661171
)
11671172

11681173
# Calculate quality metrics
1169-
total_retries_count = sum(node_retry_counts.values())
1170-
nodes_failed = list(node_retry_counts.keys())
1174+
total_retries_count = sum(node_retry_totals.values())
1175+
nodes_failed = list(node_retry_totals.keys())
11711176

11721177
if self.runtime_logger:
11731178
await self.runtime_logger.end_run(
@@ -1199,7 +1204,7 @@ async def execute(
11991204
path=path,
12001205
total_retries=total_retries_count,
12011206
nodes_with_failures=nodes_failed,
1202-
retry_details=dict(node_retry_counts),
1207+
retry_details=dict(node_retry_totals),
12031208
had_partial_failures=len(nodes_failed) > 0,
12041209
execution_quality="failed",
12051210
node_visit_counts=dict(node_visit_counts),
@@ -1237,8 +1242,8 @@ async def execute(
12371242
)
12381243

12391244
# Calculate quality metrics
1240-
total_retries_count = sum(node_retry_counts.values())
1241-
nodes_failed = [nid for nid, count in node_retry_counts.items() if count > 0]
1245+
total_retries_count = sum(node_retry_totals.values())
1246+
nodes_failed = [nid for nid, count in node_retry_totals.items() if count > 0]
12421247
exec_quality = "degraded" if total_retries_count > 0 else "clean"
12431248

12441249
if self.runtime_logger:
@@ -1260,7 +1265,7 @@ async def execute(
12601265
session_state=session_state_out,
12611266
total_retries=total_retries_count,
12621267
nodes_with_failures=nodes_failed,
1263-
retry_details=dict(node_retry_counts),
1268+
retry_details=dict(node_retry_totals),
12641269
had_partial_failures=len(nodes_failed) > 0,
12651270
execution_quality=exec_quality,
12661271
node_visit_counts=dict(node_visit_counts),
@@ -1387,7 +1392,7 @@ async def execute(
13871392
execution_path=list(path),
13881393
memory=memory,
13891394
next_node=next_node,
1390-
is_clean=(sum(node_retry_counts.values()) == 0),
1395+
is_clean=(sum(node_retry_totals.values()) == 0),
13911396
)
13921397

13931398
if checkpoint_config.async_checkpoint:
@@ -1577,8 +1582,8 @@ async def execute(
15771582
self.logger.info(f" Total latency: {total_latency}ms")
15781583

15791584
# Calculate execution quality metrics
1580-
total_retries_count = sum(node_retry_counts.values())
1581-
nodes_failed = [nid for nid, count in node_retry_counts.items() if count > 0]
1585+
total_retries_count = sum(node_retry_totals.values())
1586+
nodes_failed = [nid for nid, count in node_retry_totals.items() if count > 0]
15821587
exec_quality = "degraded" if total_retries_count > 0 else "clean"
15831588

15841589
# Update narrative to reflect execution quality
@@ -1613,7 +1618,7 @@ async def execute(
16131618
path=path,
16141619
total_retries=total_retries_count,
16151620
nodes_with_failures=nodes_failed,
1616-
retry_details=dict(node_retry_counts),
1621+
retry_details=dict(node_retry_totals),
16171622
had_partial_failures=len(nodes_failed) > 0,
16181623
execution_quality=exec_quality,
16191624
node_visit_counts=dict(node_visit_counts),
@@ -1665,8 +1670,8 @@ async def execute(
16651670
}
16661671

16671672
# Calculate quality metrics
1668-
total_retries_count = sum(node_retry_counts.values())
1669-
nodes_failed = [nid for nid, count in node_retry_counts.items() if count > 0]
1673+
total_retries_count = sum(node_retry_totals.values())
1674+
nodes_failed = [nid for nid, count in node_retry_totals.items() if count > 0]
16701675
exec_quality = "degraded" if total_retries_count > 0 else "clean"
16711676

16721677
if self.runtime_logger:
@@ -1690,7 +1695,7 @@ async def execute(
16901695
session_state=session_state_out,
16911696
total_retries=total_retries_count,
16921697
nodes_with_failures=nodes_failed,
1693-
retry_details=dict(node_retry_counts),
1698+
retry_details=dict(node_retry_totals),
16941699
had_partial_failures=len(nodes_failed) > 0,
16951700
execution_quality=exec_quality,
16961701
node_visit_counts=dict(node_visit_counts),
@@ -1722,8 +1727,8 @@ async def execute(
17221727
)
17231728

17241729
# Calculate quality metrics even for exceptions
1725-
total_retries_count = sum(node_retry_counts.values())
1726-
nodes_failed = list(node_retry_counts.keys())
1730+
total_retries_count = sum(node_retry_totals.values())
1731+
nodes_failed = list(node_retry_totals.keys())
17271732

17281733
if self.runtime_logger:
17291734
await self.runtime_logger.end_run(
@@ -1789,7 +1794,7 @@ async def execute(
17891794
path=path,
17901795
total_retries=total_retries_count,
17911796
nodes_with_failures=nodes_failed,
1792-
retry_details=dict(node_retry_counts),
1797+
retry_details=dict(node_retry_totals),
17931798
had_partial_failures=len(nodes_failed) > 0,
17941799
execution_quality="failed",
17951800
node_visit_counts=dict(node_visit_counts),

0 commit comments

Comments (0)