Skip to content

Commit 9a78b3e

Browse files
authored
fix(autofix) Retry stream from scratch (#1748)
[Investigation](https://github.com/getsentry/ml-models/blob/main/autofix/retry_stream/investigate_errors.ipynb) found that backoff would address almost all of the overloaded errors. Fix #1671 by retrying the completion from scratch, unlike #1675. ## How has this been tested? New unit tests and local runs described below <details> <summary>Local autofix runs</summary> Repro issue locally: 1. `git checkout main` 2. Replace all `AnthropicProvider`s with `AnthropicProviderFlaky`s in `coding/components.py`. 3. Run the Autofix coding step on a local issue, confirm you see the overloaded error and the run fails. 4. `git stash` the stuff in step 2 b/c we'll use it again when testing the fix. Test fix: 1. `git checkout kddubey/autofix/retry-stream-from-scratch` 2. `git stash pop` to replace all `AnthropicProvider`s with `AnthropicProviderFlaky`s in `coding/components.py`. 3. Run Autofix on a local issue, confirm you see the overloaded error in your logs, but it quickly tries again and finishes the run. 4. `git stash` to undo step (2) to test an unflaky API 5. Run Autofix on the same issue, confirm the run finishes fine. </summary>
1 parent b21b097 commit 9a78b3e

9 files changed

+356
-3
lines changed

src/seer/automation/agent/client.py

+17
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ def _get_config(cls, model_name: str):
9090
return config
9191
return None
9292

93+
@staticmethod
94+
def is_completion_exception_retryable(exception: Exception) -> bool:
95+
return False
96+
9397
def generate_text(
9498
self,
9599
*,
@@ -408,6 +412,15 @@ def _get_config(cls, model_name: str):
408412
return config
409413
return None
410414

415+
@staticmethod
416+
def is_completion_exception_retryable(exception: Exception) -> bool:
417+
if isinstance(exception, anthropic.APIStatusError):
418+
return exception.status_code == 529
419+
# https://docs.anthropic.com/en/api/errors#http-errors
420+
return isinstance(exception, anthropic.AnthropicError) and (
421+
"overloaded_error" in str(exception)
422+
)
423+
411424
@observe(as_type="generation", name="Anthropic Generation")
412425
@inject
413426
def generate_text(
@@ -708,6 +721,10 @@ def search_the_web(self, prompt: str, temperature: float | None = None) -> str:
708721
answer += each.text
709722
return answer
710723

724+
@staticmethod
725+
def is_completion_exception_retryable(exception: Exception) -> bool:
726+
return False
727+
711728
@observe(as_type="generation", name="Gemini Generation")
712729
def generate_structured(
713730
self,

src/seer/automation/autofix/autofix_agent.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import contextlib
22
import logging
33
from concurrent.futures import Executor, Future, ThreadPoolExecutor
4-
from typing import Optional
4+
from typing import Callable, Optional
55

66
from seer.automation.agent.agent import AgentConfig, LlmAgent, RunConfig
77
from seer.automation.agent.models import (
@@ -17,6 +17,7 @@
1717
from seer.automation.autofix.models import AutofixContinuation, AutofixStatus, DefaultStep
1818
from seer.automation.state import State
1919
from seer.dependency_injection import copy_modules_initializer
20+
from seer.utils import backoff_on_exception
2021

2122
logger = logging.getLogger(__name__)
2223

@@ -69,7 +70,7 @@ def _check_prompt_for_help(self, run_config: RunConfig):
6970
"You're taking a while. If you need help, ask me a concrete question using the tool provided."
7071
)
7172

72-
def get_completion(self, run_config: RunConfig):
73+
def _get_completion(self, run_config: RunConfig):
7374
"""
7475
Streams the preliminary output to the current step and only returns when output is complete
7576
"""
@@ -118,6 +119,27 @@ def get_completion(self, run_config: RunConfig):
118119
),
119120
)
120121

122+
def get_completion(
123+
self,
124+
run_config: RunConfig,
125+
max_tries: int = 4,
126+
sleep_sec_scaler: Callable[[int], float] = lambda num_tries: 2**num_tries,
127+
):
128+
"""
129+
Streams the preliminary output to the current step and only returns when output is complete.
130+
131+
The completion request is retried `max_tries - 1` times if a retryable exception was just
132+
raised, e.g, Anthropic's API is overloaded.
133+
"""
134+
is_exception_retryable = getattr(
135+
run_config.model, "is_completion_exception_retryable", lambda _: False
136+
)
137+
retrier = backoff_on_exception(
138+
is_exception_retryable, max_tries=max_tries, sleep_sec_scaler=sleep_sec_scaler
139+
)
140+
get_completion_retryable = retrier(self._get_completion)
141+
return get_completion_retryable(run_config)
142+
121143
def run_iteration(self, run_config: RunConfig):
122144
logger.debug(f"----[{self.name}] Running Iteration {self.iterations}----")
123145

src/seer/utils.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
import contextlib
22
import functools
33
import json
4+
import logging
5+
import random
6+
import time
47
import weakref
58
from enum import Enum
69
from queue import Empty, Full, Queue
7-
from typing import Sequence
10+
from typing import Callable, Sequence
811

912
from sqlalchemy.orm import DeclarativeBase, Session
1013

14+
logger = logging.getLogger(__name__)
15+
1116

1217
def class_method_lru_cache(*lru_args, **lru_kwargs):
1318
def decorator(func):
@@ -68,3 +73,44 @@ def closing_queue(*queues: Queue):
6873
queue.get_nowait()
6974
except Empty:
7075
pass
76+
77+
78+
def backoff_on_exception(
79+
is_exception_retryable: Callable[[Exception], bool],
80+
max_tries: int = 2,
81+
sleep_sec_scaler: Callable[[int], float] = lambda num_tries: 2**num_tries,
82+
jitterer: Callable[[], float] = lambda: random.uniform(0, 0.5),
83+
):
84+
"""
85+
Returns a decorator which retries a function on exception iff `is_exception_retryable(exception)`.
86+
Defaults to exponential backoff with random jitter and one retry.
87+
"""
88+
89+
if max_tries < 1:
90+
raise ValueError("max_tries must be at least 1") # pragma: no cover
91+
92+
def decorator(func):
93+
@functools.wraps(func)
94+
def wrapped_func(*args, **kwargs):
95+
num_tries = 0
96+
last_exception = None
97+
while num_tries < max_tries:
98+
try:
99+
return func(*args, **kwargs)
100+
except Exception as exception:
101+
num_tries += 1
102+
last_exception = exception
103+
if is_exception_retryable(exception):
104+
sleep_sec = sleep_sec_scaler(num_tries) + jitterer()
105+
logger.info(
106+
f"Encountered {type(exception).__name__}: {exception}. Sleeping for "
107+
f"{sleep_sec} seconds before attempting retry {num_tries}/{max_tries}."
108+
)
109+
time.sleep(sleep_sec)
110+
else:
111+
raise exception
112+
raise last_exception
113+
114+
return wrapped_func
115+
116+
return decorator

0 commit comments

Comments
 (0)