Skip to content

Commit 57a3aa7

Browse files
authored
feat(gateway): support api_type (#3362)
1 parent faca9c4 commit 57a3aa7

File tree

14 files changed

+163
-114
lines changed

14 files changed

+163
-114
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,9 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
764764

765765
model_kind = provider_name
766766
if model_kind.startswith('gateway/'):
767-
model_kind = provider_name.removeprefix('gateway/')
767+
from ..providers.gateway import infer_gateway_model
768+
769+
return infer_gateway_model(model_kind.removeprefix('gateway/'), model_name=model_name)
768770
if model_kind in (
769771
'openai',
770772
'azure',

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def __init__(
226226
self._model_name = model_name
227227

228228
if isinstance(provider, str):
229-
provider = infer_provider('gateway/bedrock' if provider == 'gateway' else provider)
229+
provider = infer_provider('gateway/converse' if provider == 'gateway' else provider)
230230
self._provider = provider
231231
self.client = cast('BedrockRuntimeClient', provider.client)
232232

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def __init__(
204204
self._model_name = model_name
205205

206206
if isinstance(provider, str):
207-
provider = infer_provider('gateway/google-vertex' if provider == 'gateway' else provider)
207+
provider = infer_provider('gateway/gemini' if provider == 'gateway' else provider)
208208
self._provider = provider
209209
self.client = provider.client
210210

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def __init__(
375375
self._model_name = model_name
376376

377377
if isinstance(provider, str):
378-
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
378+
provider = infer_provider('gateway/chat' if provider == 'gateway' else provider)
379379
self._provider = provider
380380
self.client = provider.client
381381

@@ -944,7 +944,7 @@ def __init__(
944944
self._model_name = model_name
945945

946946
if isinstance(provider, str):
947-
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
947+
provider = infer_provider('gateway/responses' if provider == 'gateway' else provider)
948948
self._provider = provider
949949
self.client = provider.client
950950

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ def infer_provider(provider: str) -> Provider[Any]:
158158
if provider.startswith('gateway/'):
159159
from .gateway import gateway_provider
160160

161-
provider = provider.removeprefix('gateway/')
162-
return gateway_provider(provider)
161+
api_type = provider.removeprefix('gateway/')
162+
return gateway_provider(api_type)
163163
elif provider in ('google-vertex', 'google-gla'):
164164
from .google import GoogleProvider
165165

pydantic_ai_slim/pydantic_ai/providers/gateway.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from groq import AsyncGroq
1818
from openai import AsyncOpenAI
1919

20+
from pydantic_ai.models import Model
2021
from pydantic_ai.models.anthropic import AsyncAnthropicClient
2122
from pydantic_ai.providers import Provider
2223

@@ -25,7 +26,8 @@
2526

2627
@overload
2728
def gateway_provider(
28-
upstream_provider: Literal['openai', 'openai-chat', 'openai-responses'],
29+
api_type: Literal['chat', 'responses'],
30+
/,
2931
*,
3032
api_key: str | None = None,
3133
base_url: str | None = None,
@@ -35,7 +37,8 @@ def gateway_provider(
3537

3638
@overload
3739
def gateway_provider(
38-
upstream_provider: Literal['groq'],
40+
api_type: Literal['groq'],
41+
/,
3942
*,
4043
api_key: str | None = None,
4144
base_url: str | None = None,
@@ -45,56 +48,63 @@ def gateway_provider(
4548

4649
@overload
4750
def gateway_provider(
48-
upstream_provider: Literal['google-vertex'],
51+
api_type: Literal['anthropic'],
52+
/,
4953
*,
5054
api_key: str | None = None,
5155
base_url: str | None = None,
52-
) -> Provider[GoogleClient]: ...
56+
http_client: httpx.AsyncClient | None = None,
57+
) -> Provider[AsyncAnthropicClient]: ...
5358

5459

5560
@overload
5661
def gateway_provider(
57-
upstream_provider: Literal['anthropic'],
62+
api_type: Literal['converse'],
63+
/,
5864
*,
5965
api_key: str | None = None,
6066
base_url: str | None = None,
61-
) -> Provider[AsyncAnthropicClient]: ...
67+
) -> Provider[BaseClient]: ...
6268

6369

6470
@overload
6571
def gateway_provider(
66-
upstream_provider: Literal['bedrock'],
72+
api_type: Literal['gemini'],
73+
/,
6774
*,
6875
api_key: str | None = None,
6976
base_url: str | None = None,
70-
) -> Provider[BaseClient]: ...
77+
http_client: httpx.AsyncClient | None = None,
78+
) -> Provider[GoogleClient]: ...
7179

7280

7381
@overload
7482
def gateway_provider(
75-
upstream_provider: str,
83+
api_type: str,
84+
/,
7685
*,
7786
api_key: str | None = None,
7887
base_url: str | None = None,
7988
) -> Provider[Any]: ...
8089

8190

82-
UpstreamProvider = Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic', 'bedrock']
91+
APIType = Literal['chat', 'responses', 'gemini', 'converse', 'anthropic', 'groq']
8392

8493

8594
def gateway_provider(
86-
upstream_provider: UpstreamProvider | str,
95+
api_type: APIType | str,
96+
/,
8797
*,
8898
# Every provider
8999
api_key: str | None = None,
90100
base_url: str | None = None,
91-
# OpenAI, Groq & Anthropic
101+
# OpenAI, Groq, Anthropic & Gemini - Only Bedrock doesn't have an HTTPX client.
92102
http_client: httpx.AsyncClient | None = None,
93103
) -> Provider[Any]:
94104
"""Create a new Gateway provider.
95105
96106
Args:
97-
upstream_provider: The upstream provider to use.
107+
api_type: Determines the API type to use.
98108
api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
99109
environment variable will be used if available.
100110
base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
@@ -109,18 +119,18 @@ def gateway_provider(
109119
)
110120

111121
base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', GATEWAY_BASE_URL)
112-
http_client = http_client or cached_async_http_client(provider=f'gateway/{upstream_provider}')
122+
http_client = http_client or cached_async_http_client(provider=f'gateway/{api_type}')
113123
http_client.event_hooks = {'request': [_request_hook(api_key)]}
114124

115-
if upstream_provider in ('openai', 'openai-chat', 'openai-responses'):
125+
if api_type in ('chat', 'responses'):
116126
from .openai import OpenAIProvider
117127

118-
return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'openai'), http_client=http_client)
119-
elif upstream_provider == 'groq':
128+
return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, api_type), http_client=http_client)
129+
elif api_type == 'groq':
120130
from .groq import GroqProvider
121131

122132
return GroqProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'groq'), http_client=http_client)
123-
elif upstream_provider == 'anthropic':
133+
elif api_type == 'anthropic':
124134
from anthropic import AsyncAnthropic
125135

126136
from .anthropic import AnthropicProvider
@@ -132,25 +142,25 @@ def gateway_provider(
132142
http_client=http_client,
133143
)
134144
)
135-
elif upstream_provider == 'bedrock':
145+
elif api_type == 'converse':
136146
from .bedrock import BedrockProvider
137147

138148
return BedrockProvider(
139149
api_key=api_key,
140-
base_url=_merge_url_path(base_url, 'bedrock'),
150+
base_url=_merge_url_path(base_url, api_type),
141151
region_name='pydantic-ai-gateway', # Fake region name to avoid NoRegionError
142152
)
143-
elif upstream_provider == 'google-vertex':
153+
elif api_type == 'gemini':
144154
from .google import GoogleProvider
145155

146156
return GoogleProvider(
147157
vertexai=True,
148158
api_key=api_key,
149-
base_url=_merge_url_path(base_url, 'google-vertex'),
159+
base_url=_merge_url_path(base_url, 'gemini'),
150160
http_client=http_client,
151161
)
152162
else:
153-
raise UserError(f'Unknown upstream provider: {upstream_provider}')
163+
raise UserError(f'Unknown API type: {api_type}')
154164

155165

156166
def _request_hook(api_key: str) -> Callable[[httpx.Request], Awaitable[httpx.Request]]:
@@ -182,3 +192,33 @@ def _merge_url_path(base_url: str, path: str) -> str:
182192
path: The path to merge.
183193
"""
184194
return base_url.rstrip('/') + '/' + path.lstrip('/')
195+
196+
197+
def infer_gateway_model(api_type: APIType | str, *, model_name: str) -> Model:
198+
"""Infer the model class for a given API type."""
199+
if api_type == 'chat':
200+
from pydantic_ai.models.openai import OpenAIChatModel
201+
202+
return OpenAIChatModel(model_name=model_name, provider='gateway')
203+
elif api_type == 'groq':
204+
from pydantic_ai.models.groq import GroqModel
205+
206+
return GroqModel(model_name=model_name, provider='gateway')
207+
elif api_type == 'responses':
208+
from pydantic_ai.models.openai import OpenAIResponsesModel
209+
210+
return OpenAIResponsesModel(model_name=model_name, provider='gateway')
211+
elif api_type == 'gemini':
212+
from pydantic_ai.models.google import GoogleModel
213+
214+
return GoogleModel(model_name=model_name, provider='gateway')
215+
elif api_type == 'converse':
216+
from pydantic_ai.models.bedrock import BedrockConverseModel
217+
218+
return BedrockConverseModel(model_name=model_name, provider='gateway')
219+
elif api_type == 'anthropic':
220+
from pydantic_ai.models.anthropic import AnthropicModel
221+
222+
return AnthropicModel(model_name=model_name, provider='gateway')
223+
else:
224+
raise ValueError(f'Unknown API type: {api_type}') # pragma: no cover

tests/models/cassettes/test_model_names/test_known_model_names.yaml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,31 +134,31 @@ interactions:
134134
parsed_body:
135135
data:
136136
- created: 0
137-
id: qwen-3-32b
137+
id: qwen-3-235b-a22b-thinking-2507
138138
object: model
139139
owned_by: Cerebras
140140
- created: 0
141-
id: qwen-3-235b-a22b-instruct-2507
141+
id: llama-3.3-70b
142142
object: model
143143
owned_by: Cerebras
144144
- created: 0
145-
id: gpt-oss-120b
145+
id: qwen-3-235b-a22b-instruct-2507
146146
object: model
147147
owned_by: Cerebras
148148
- created: 0
149-
id: zai-glm-4.6
149+
id: qwen-3-32b
150150
object: model
151151
owned_by: Cerebras
152152
- created: 0
153-
id: llama3.1-8b
153+
id: zai-glm-4.6
154154
object: model
155155
owned_by: Cerebras
156156
- created: 0
157-
id: llama-3.3-70b
157+
id: gpt-oss-120b
158158
object: model
159159
owned_by: Cerebras
160160
- created: 0
161-
id: qwen-3-235b-a22b-thinking-2507
161+
id: llama3.1-8b
162162
object: model
163163
owned_by: Cerebras
164164
object: list

tests/models/test_model.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,30 +29,21 @@
2929
TEST_CASES = [
3030
pytest.param(
3131
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
32-
'gateway/openai:gpt-5',
32+
'gateway/chat:gpt-5',
3333
'gpt-5',
3434
'openai',
3535
'openai',
3636
OpenAIChatModel,
37-
id='gateway/openai:gpt-5',
37+
id='gateway/chat:gpt-5',
3838
),
3939
pytest.param(
4040
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
41-
'gateway/openai-chat:gpt-5',
42-
'gpt-5',
43-
'openai',
44-
'openai',
45-
OpenAIChatModel,
46-
id='gateway/openai-chat:gpt-5',
47-
),
48-
pytest.param(
49-
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
50-
'gateway/openai-responses:gpt-5',
41+
'gateway/responses:gpt-5',
5142
'gpt-5',
5243
'openai',
5344
'openai',
5445
OpenAIResponsesModel,
55-
id='gateway/openai-responses:gpt-5',
46+
id='gateway/responses:gpt-5',
5647
),
5748
pytest.param(
5849
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
@@ -65,12 +56,12 @@
6556
),
6657
pytest.param(
6758
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
68-
'gateway/google-vertex:gemini-1.5-flash',
59+
'gateway/gemini:gemini-1.5-flash',
6960
'gemini-1.5-flash',
7061
'google-vertex',
7162
'google',
7263
GoogleModel,
73-
id='gateway/google-vertex:gemini-1.5-flash',
64+
id='gateway/gemini:gemini-1.5-flash',
7465
),
7566
pytest.param(
7667
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
@@ -83,12 +74,12 @@
8374
),
8475
pytest.param(
8576
{'PYDANTIC_AI_GATEWAY_API_KEY': 'gateway-api-key'},
86-
'gateway/bedrock:amazon.nova-micro-v1:0',
77+
'gateway/converse:amazon.nova-micro-v1:0',
8778
'amazon.nova-micro-v1:0',
8879
'bedrock',
8980
'bedrock',
9081
BedrockConverseModel,
91-
id='gateway/bedrock:amazon.nova-micro-v1:0',
82+
id='gateway/converse:amazon.nova-micro-v1:0',
9283
),
9384
pytest.param(
9485
{'OPENAI_API_KEY': 'openai-api-key'},

tests/providers/cassettes/test_gateway/test_gateway_provider_with_bedrock.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ interactions:
55
headers:
66
amz-sdk-invocation-id:
77
- !!binary |
8-
MWYwNDlkMTQtMjVmMC00YTRhLWJhYmMtNTQ0MDdhMmRlNjgw
8+
MmEzMzkzMGUtNzI3YS00YzFhLWFmYWQtYzFhYWMyMTI3NDlj
99
amz-sdk-request:
1010
- !!binary |
1111
YXR0ZW1wdD0x
@@ -15,7 +15,7 @@ interactions:
1515
- !!binary |
1616
YXBwbGljYXRpb24vanNvbg==
1717
method: POST
18-
uri: http://localhost:8787/bedrock/model/amazon.nova-micro-v1%3A0/converse
18+
uri: http://localhost:8787/converse/model/amazon.nova-micro-v1%3A0/converse
1919
response:
2020
headers:
2121
content-length:
@@ -26,7 +26,7 @@ interactions:
2626
- 0.0000USD
2727
parsed_body:
2828
metrics:
29-
latencyMs: 668
29+
latencyMs: 682
3030
output:
3131
message:
3232
content:

0 commit comments

Comments
 (0)