Skip to content

Commit b9f29d1

Browse files
committed
Add AdaptiveModel for custom model selection logic
1 parent 5768447 commit b9f29d1

File tree

1 file changed

+297
-0
lines changed

1 file changed

+297
-0
lines changed
Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
from __future__ import annotations as _annotations
2+
3+
import inspect
4+
import time
5+
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
6+
from contextlib import AsyncExitStack, asynccontextmanager, suppress
7+
from dataclasses import dataclass
8+
from typing import TYPE_CHECKING, Generic, TypeVar
9+
10+
from opentelemetry.trace import get_current_span
11+
12+
from pydantic_ai._run_context import RunContext
13+
from pydantic_ai.models.instrumented import InstrumentedModel
14+
15+
from ..exceptions import FallbackExceptionGroup
16+
from ..settings import merge_model_settings
17+
from . import Model, ModelRequestParameters, StreamedResponse
18+
19+
if TYPE_CHECKING:
20+
from ..messages import ModelMessage, ModelResponse
21+
from ..settings import ModelSettings
22+
23+
AgentDepsT = TypeVar('AgentDepsT')
24+
25+
26+
@dataclass
27+
class AttemptResult:
28+
"""Record of a single attempt to use a model."""
29+
30+
model: Model
31+
"""The model that was attempted."""
32+
33+
exception: Exception | None
34+
"""The exception raised by the model, if any."""
35+
36+
timestamp: float
37+
"""Unix timestamp when the attempt was made."""
38+
39+
duration: float
40+
"""Duration of the attempt in seconds."""
41+
42+
43+
@dataclass
44+
class AdaptiveContext(Generic[AgentDepsT]):
45+
"""Context provided to the selector function."""
46+
47+
run_context: RunContext[AgentDepsT] | None
48+
"""Access to agent dependencies. May be None for non-streaming requests."""
49+
50+
models: Sequence[Model]
51+
"""Available models to choose from."""
52+
53+
attempts: list[AttemptResult]
54+
"""History of attempts in this request."""
55+
56+
attempt_number: int
57+
"""Current attempt number (1-indexed)."""
58+
59+
messages: list[ModelMessage]
60+
"""The original request messages."""
61+
62+
model_settings: ModelSettings | None
63+
"""Model settings for this request."""
64+
65+
model_request_parameters: ModelRequestParameters
66+
"""Model request parameters."""
67+
68+
69+
@dataclass(init=False)
70+
class AdaptiveModel(Model, Generic[AgentDepsT]):
71+
"""A model that uses custom logic to select which model to try next.
72+
73+
Unlike FallbackModel which tries models sequentially, AdaptiveModel gives
74+
full control over model selection based on rich context including attempts,
75+
exceptions, and agent dependencies.
76+
77+
The selector function is called before each attempt and can:
78+
- Return a Model to try next (can be the same model for retry)
79+
- Return None to stop trying
80+
- Use async/await for delays (exponential backoff, etc.)
81+
- Access agent dependencies via ctx.run_context.deps
82+
- Inspect previous attempts via ctx.attempts
83+
"""
84+
85+
models: Sequence[Model]
86+
_selector: (
87+
Callable[[AdaptiveContext[AgentDepsT]], Model | None]
88+
| Callable[[AdaptiveContext[AgentDepsT]], Awaitable[Model | None]]
89+
)
90+
_max_attempts: int | None
91+
92+
def __init__(
93+
self,
94+
models: Sequence[Model],
95+
selector: Callable[[AdaptiveContext[AgentDepsT]], Model | None]
96+
| Callable[[AdaptiveContext[AgentDepsT]], Awaitable[Model | None]],
97+
*,
98+
max_attempts: int | None = None,
99+
):
100+
"""Initialize an adaptive model instance.
101+
102+
Args:
103+
models: Pool of models to choose from.
104+
selector: Sync or async function that selects the next model to try.
105+
Called before each attempt with context including previous attempts.
106+
Return a Model to try, or None to stop.
107+
max_attempts: Maximum total attempts across all models (None = unlimited).
108+
"""
109+
super().__init__()
110+
if not models:
111+
raise ValueError('At least one model must be provided')
112+
113+
self.models = list(models)
114+
self._selector = selector
115+
self._max_attempts = max_attempts
116+
117+
@property
118+
def model_name(self) -> str:
119+
"""The model name."""
120+
return f'adaptive:{",".join(model.model_name for model in self.models)}'
121+
122+
@property
123+
def system(self) -> str:
124+
return f'adaptive:{",".join(model.system for model in self.models)}'
125+
126+
@property
127+
def base_url(self) -> str | None:
128+
return self.models[0].base_url if self.models else None
129+
130+
async def request(
131+
self,
132+
messages: list[ModelMessage],
133+
model_settings: ModelSettings | None,
134+
model_request_parameters: ModelRequestParameters,
135+
) -> ModelResponse:
136+
"""Try models based on selector logic until one succeeds or selector returns None."""
137+
attempts: list[AttemptResult] = []
138+
attempt_number = 0
139+
140+
while True:
141+
attempt_number += 1
142+
143+
# Check max attempts
144+
if self._max_attempts is not None and attempt_number > self._max_attempts:
145+
exceptions = [a.exception for a in attempts if a.exception is not None]
146+
if exceptions:
147+
raise FallbackExceptionGroup(
148+
f'AdaptiveModel exceeded max_attempts of {self._max_attempts}', exceptions
149+
)
150+
else:
151+
raise FallbackExceptionGroup(
152+
f'AdaptiveModel exceeded max_attempts of {self._max_attempts}',
153+
[RuntimeError('No models were attempted')],
154+
)
155+
156+
# Create context for selector
157+
context = AdaptiveContext(
158+
run_context=None, # run_context not available in non-streaming request
159+
models=self.models,
160+
attempts=attempts,
161+
attempt_number=attempt_number,
162+
messages=messages,
163+
model_settings=model_settings,
164+
model_request_parameters=model_request_parameters,
165+
)
166+
167+
# Call selector to get next model
168+
model = await self._call_selector(context)
169+
170+
if model is None:
171+
# Selector says stop trying
172+
exceptions = [a.exception for a in attempts if a.exception is not None]
173+
if exceptions:
174+
raise FallbackExceptionGroup('AdaptiveModel selector returned None', exceptions)
175+
else:
176+
raise FallbackExceptionGroup(
177+
'AdaptiveModel selector returned None', [RuntimeError('No models were attempted')]
178+
)
179+
180+
# Try the selected model
181+
start_time = time.time()
182+
customized_params = model.customize_request_parameters(model_request_parameters)
183+
merged_settings = merge_model_settings(model.settings, model_settings)
184+
185+
try:
186+
response = await model.request(messages, merged_settings, customized_params)
187+
# Success! Set span attributes and return
188+
self._set_span_attributes(model)
189+
return response
190+
except Exception as exc:
191+
# Record the attempt
192+
duration = time.time() - start_time
193+
attempts.append(
194+
AttemptResult(
195+
model=model,
196+
exception=exc,
197+
timestamp=start_time,
198+
duration=duration,
199+
)
200+
)
201+
# Continue loop to try again
202+
203+
@asynccontextmanager
204+
async def request_stream(
205+
self,
206+
messages: list[ModelMessage],
207+
model_settings: ModelSettings | None,
208+
model_request_parameters: ModelRequestParameters,
209+
run_context: RunContext[AgentDepsT] | None = None,
210+
) -> AsyncIterator[StreamedResponse]:
211+
"""Try models based on selector logic until one succeeds or selector returns None."""
212+
attempts: list[AttemptResult] = []
213+
attempt_number = 0
214+
215+
while True:
216+
attempt_number += 1
217+
218+
# Check max attempts
219+
if self._max_attempts is not None and attempt_number > self._max_attempts:
220+
exceptions = [a.exception for a in attempts if a.exception is not None]
221+
if exceptions:
222+
raise FallbackExceptionGroup(
223+
f'AdaptiveModel exceeded max_attempts of {self._max_attempts}', exceptions
224+
)
225+
else:
226+
raise FallbackExceptionGroup(
227+
f'AdaptiveModel exceeded max_attempts of {self._max_attempts}',
228+
[RuntimeError('No models were attempted')],
229+
)
230+
231+
# Create context for selector
232+
context = AdaptiveContext(
233+
run_context=run_context,
234+
models=self.models,
235+
attempts=attempts,
236+
attempt_number=attempt_number,
237+
messages=messages,
238+
model_settings=model_settings,
239+
model_request_parameters=model_request_parameters,
240+
)
241+
242+
# Call selector to get next model
243+
model = await self._call_selector(context)
244+
245+
if model is None:
246+
# Selector says stop trying
247+
exceptions = [a.exception for a in attempts if a.exception is not None]
248+
if exceptions:
249+
raise FallbackExceptionGroup('AdaptiveModel selector returned None', exceptions)
250+
else:
251+
raise FallbackExceptionGroup(
252+
'AdaptiveModel selector returned None', [RuntimeError('No models were attempted')]
253+
)
254+
255+
# Try the selected model
256+
start_time = time.time()
257+
customized_params = model.customize_request_parameters(model_request_parameters)
258+
merged_settings = merge_model_settings(model.settings, model_settings)
259+
260+
async with AsyncExitStack() as stack:
261+
try:
262+
response = await stack.enter_async_context(
263+
model.request_stream(messages, merged_settings, customized_params, run_context)
264+
)
265+
except Exception as exc:
266+
# Record the attempt and continue
267+
duration = time.time() - start_time
268+
attempts.append(
269+
AttemptResult(
270+
model=model,
271+
exception=exc,
272+
timestamp=start_time,
273+
duration=duration,
274+
)
275+
)
276+
continue
277+
278+
# Success! Set span attributes and yield
279+
self._set_span_attributes(model)
280+
yield response
281+
return
282+
283+
async def _call_selector(self, context: AdaptiveContext[AgentDepsT]) -> Model | None:
284+
"""Call the selector function, handling both sync and async."""
285+
if inspect.iscoroutinefunction(self._selector):
286+
return await self._selector(context)
287+
else:
288+
return self._selector(context) # type: ignore
289+
290+
def _set_span_attributes(self, model: Model):
291+
"""Set OpenTelemetry span attributes for the successful model."""
292+
with suppress(Exception):
293+
span = get_current_span()
294+
if span.is_recording():
295+
attributes = getattr(span, 'attributes', {})
296+
if attributes.get('gen_ai.request.model') == self.model_name: # pragma: no branch
297+
span.set_attributes(InstrumentedModel.model_attributes(model))

0 commit comments

Comments
 (0)