1717 from groq import AsyncGroq
1818 from openai import AsyncOpenAI
1919
20+ from pydantic_ai .models import Model
2021 from pydantic_ai .models .anthropic import AsyncAnthropicClient
2122 from pydantic_ai .providers import Provider
2223
2526
2627@overload
2728def gateway_provider (
28- upstream_provider : Literal ['openai' , 'openai-chat' , 'openai-responses' ],
29+ api_type : Literal ['chat' , 'responses' ],
30+ / ,
2931 * ,
3032 api_key : str | None = None ,
3133 base_url : str | None = None ,
@@ -35,7 +37,8 @@ def gateway_provider(
3537
3638@overload
3739def gateway_provider (
38- upstream_provider : Literal ['groq' ],
40+ api_type : Literal ['groq' ],
41+ / ,
3942 * ,
4043 api_key : str | None = None ,
4144 base_url : str | None = None ,
@@ -45,56 +48,63 @@ def gateway_provider(
4548
4649@overload
4750def gateway_provider (
48- upstream_provider : Literal ['google-vertex' ],
51+ api_type : Literal ['anthropic' ],
52+ / ,
4953 * ,
5054 api_key : str | None = None ,
5155 base_url : str | None = None ,
52- ) -> Provider [GoogleClient ]: ...
56+ http_client : httpx .AsyncClient | None = None ,
57+ ) -> Provider [AsyncAnthropicClient ]: ...
5358
5459
5560@overload
5661def gateway_provider (
57- upstream_provider : Literal ['anthropic' ],
62+ api_type : Literal ['converse' ],
63+ / ,
5864 * ,
5965 api_key : str | None = None ,
6066 base_url : str | None = None ,
61- ) -> Provider [AsyncAnthropicClient ]: ...
67+ ) -> Provider [BaseClient ]: ...
6268
6369
6470@overload
6571def gateway_provider (
66- upstream_provider : Literal ['bedrock' ],
72+ api_type : Literal ['gemini' ],
73+ / ,
6774 * ,
6875 api_key : str | None = None ,
6976 base_url : str | None = None ,
70- ) -> Provider [BaseClient ]: ...
77+ http_client : httpx .AsyncClient | None = None ,
78+ ) -> Provider [GoogleClient ]: ...
7179
7280
7381@overload
7482def gateway_provider (
75- upstream_provider : str ,
83+ api_type : str ,
84+ / ,
7685 * ,
7786 api_key : str | None = None ,
7887 base_url : str | None = None ,
7988) -> Provider [Any ]: ...
8089
8190
82- UpstreamProvider = Literal ['openai' , 'openai- chat' , 'openai- responses' , 'groq ' , 'google-vertex ' , 'anthropic' , 'bedrock ' ]
91+ APIType = Literal ['chat' , 'responses' , 'gemini ' , 'converse ' , 'anthropic' , 'groq ' ]
8392
8493
8594def gateway_provider (
86- upstream_provider : UpstreamProvider | str ,
95+ api_type : APIType | str ,
96+ / ,
8797 * ,
8898 # Every provider
8999 api_key : str | None = None ,
90100 base_url : str | None = None ,
91- # OpenAI, Groq & Anthropic
101+ # OpenAI, Groq, Anthropic & Gemini - Only Bedrock doesn't have an HTTPX client.
92102 http_client : httpx .AsyncClient | None = None ,
93103) -> Provider [Any ]:
94104 """Create a new Gateway provider.
95105
96106 Args:
97- upstream_provider: The upstream provider to use.
107+ api_type: Determines the API type to use.
98108 api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
99109 environment variable will be used if available.
100110 base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
@@ -109,18 +119,18 @@ def gateway_provider(
109119 )
110120
111121 base_url = base_url or os .getenv ('PYDANTIC_AI_GATEWAY_BASE_URL' , GATEWAY_BASE_URL )
112- http_client = http_client or cached_async_http_client (provider = f'gateway/{ upstream_provider } ' )
122+ http_client = http_client or cached_async_http_client (provider = f'gateway/{ api_type } ' )
113123 http_client .event_hooks = {'request' : [_request_hook (api_key )]}
114124
115- if upstream_provider in ('openai' , 'openai- chat' , 'openai- responses' ):
125+ if api_type in ('chat' , 'responses' ):
116126 from .openai import OpenAIProvider
117127
118- return OpenAIProvider (api_key = api_key , base_url = _merge_url_path (base_url , 'openai' ), http_client = http_client )
119- elif upstream_provider == 'groq' :
128+ return OpenAIProvider (api_key = api_key , base_url = _merge_url_path (base_url , api_type ), http_client = http_client )
129+ elif api_type == 'groq' :
120130 from .groq import GroqProvider
121131
122132 return GroqProvider (api_key = api_key , base_url = _merge_url_path (base_url , 'groq' ), http_client = http_client )
123- elif upstream_provider == 'anthropic' :
133+ elif api_type == 'anthropic' :
124134 from anthropic import AsyncAnthropic
125135
126136 from .anthropic import AnthropicProvider
@@ -132,25 +142,25 @@ def gateway_provider(
132142 http_client = http_client ,
133143 )
134144 )
135- elif upstream_provider == 'bedrock ' :
145+ elif api_type == 'converse ' :
136146 from .bedrock import BedrockProvider
137147
138148 return BedrockProvider (
139149 api_key = api_key ,
140- base_url = _merge_url_path (base_url , 'bedrock' ),
150+ base_url = _merge_url_path (base_url , api_type ),
141151 region_name = 'pydantic-ai-gateway' , # Fake region name to avoid NoRegionError
142152 )
143- elif upstream_provider == 'google-vertex ' :
153+ elif api_type == 'gemini ' :
144154 from .google import GoogleProvider
145155
146156 return GoogleProvider (
147157 vertexai = True ,
148158 api_key = api_key ,
149- base_url = _merge_url_path (base_url , 'google-vertex ' ),
159+ base_url = _merge_url_path (base_url , 'gemini ' ),
150160 http_client = http_client ,
151161 )
152162 else :
153- raise UserError (f'Unknown upstream provider : { upstream_provider } ' )
163+ raise UserError (f'Unknown API type : { api_type } ' )
154164
155165
156166def _request_hook (api_key : str ) -> Callable [[httpx .Request ], Awaitable [httpx .Request ]]:
@@ -182,3 +192,33 @@ def _merge_url_path(base_url: str, path: str) -> str:
182192 path: The path to merge.
183193 """
184194 return base_url .rstrip ('/' ) + '/' + path .lstrip ('/' )
195+
196+
197+ def infer_gateway_model (api_type : APIType | str , * , model_name : str ) -> Model :
198+ """Infer the model class for a given API type."""
199+ if api_type == 'chat' :
200+ from pydantic_ai .models .openai import OpenAIChatModel
201+
202+ return OpenAIChatModel (model_name = model_name , provider = 'gateway' )
203+ elif api_type == 'groq' :
204+ from pydantic_ai .models .groq import GroqModel
205+
206+ return GroqModel (model_name = model_name , provider = 'gateway' )
207+ elif api_type == 'responses' :
208+ from pydantic_ai .models .openai import OpenAIResponsesModel
209+
210+ return OpenAIResponsesModel (model_name = model_name , provider = 'gateway' )
211+ elif api_type == 'gemini' :
212+ from pydantic_ai .models .google import GoogleModel
213+
214+ return GoogleModel (model_name = model_name , provider = 'gateway' )
215+ elif api_type == 'converse' :
216+ from pydantic_ai .models .bedrock import BedrockConverseModel
217+
218+ return BedrockConverseModel (model_name = model_name , provider = 'gateway' )
219+ elif api_type == 'anthropic' :
220+ from pydantic_ai .models .anthropic import AnthropicModel
221+
222+ return AnthropicModel (model_name = model_name , provider = 'gateway' )
223+ else :
224+ raise ValueError (f'Unknown API type: { api_type } ' ) # pragma: no cover
0 commit comments