Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion examples/ex_openai.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import os

import logfire
from devtools import debug
from openai import OpenAI

logfire.configure()
logfire.instrument_httpx(capture_all=True)


GATEWAY_API_KEY = os.getenv('GATEWAY_API_KEY')
assert GATEWAY_API_KEY, 'GATEWAY_API_KEY is not set'

client = OpenAI(
api_key='VOE4JMpVGr71RgvEEidPCXd4ov42L24ODw9q5RI7uYc',
api_key=GATEWAY_API_KEY,
base_url='http://localhost:8787/openai',
# base_url='https://pydantic-ai-gateway.pydantic.workers.dev/openai',
)
Expand Down
49 changes: 49 additions & 0 deletions examples/pai_anthropic_vertex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
from datetime import date

import logfire
from anthropic import AnthropicVertex
from google.auth.api_key import Credentials
from pydantic import BaseModel, field_validator
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.providers.anthropic import AnthropicProvider

# Telemetry setup: configure logfire under the 'testing' service name and
# switch on its pydantic-ai instrumentation before any agent is built.
logfire.configure(service_name='testing')
logfire.instrument_pydantic_ai()

# The gateway key is read from the environment; the assert fails when it is
# unset or empty, and also narrows the type from `str | None` to `str`.
GATEWAY_API_KEY = os.environ.get('GATEWAY_API_KEY')
assert GATEWAY_API_KEY, 'GATEWAY_API_KEY is not set'


# Structured output type for the extraction agent.
#
# NOTE: the bare strings under each field are not ordinary comments. Because
# the model is declared with `use_attribute_docstrings=True`, pydantic uses
# them as the field descriptions, so they are part of runtime behaviour and
# must not be reworded casually. For the same reason no class docstring is
# added here (pydantic would publish it as the model's schema description).
class Person(BaseModel, use_attribute_docstrings=True):
    name: str
    """The name of the person."""
    dob: date
    """The date of birth of the person. MUST BE A VALID ISO 8601 date."""
    city: str
    """The city where the person lives."""

    @field_validator('dob')
    @classmethod
    def validate_dob(cls, v: date) -> date:
        """Reject any date of birth on or after 1 January 1900.

        Args:
            v: The already-parsed date of birth.

        Returns:
            The same date, unchanged, when it is earlier than 1900-01-01.

        Raises:
            ValueError: If ``v`` is 1900-01-01 or later.
        """
        # Only an upper bound is enforced; there is no lower limit on the year.
        if v >= date(1900, 1, 1):
            raise ValueError('The person must be born in the 19th century')
        return v


# Anthropic Vertex client aimed at the locally running gateway. `region` and
# `project_id` carry the placeholder value 'unset'; the gateway API key is
# passed as the credential token.
gateway_credentials = Credentials(token=GATEWAY_API_KEY)
client = AnthropicVertex(
    base_url='http://localhost:8787/google-vertex',
    region='unset',
    project_id='unset',
    credentials=gateway_credentials,
)

# Wrap the raw client so pydantic-ai drives the model through it.
provider = AnthropicProvider(anthropic_client=client)
model = AnthropicModel('claude-sonnet-4-20250514', provider=provider)

# Agent whose output is validated against the `Person` model.
person_agent = Agent(
    model=model,
    output_type=Person,
    instructions='Extract information about the person',
)

# Run one synchronous extraction and show the validated result.
result = person_agent.run_sync("Samuel lived in London and was born on Jan 28th '87")
print(f'{result.output!r}')
8 changes: 6 additions & 2 deletions gateway/src/providers/google/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export async function authToken(credentials: string, kv: KVNamespace): Promise<s
return token
}

function getServiceAccount(credentials: string): ServiceAccount {
export function getServiceAccount(credentials: string): ServiceAccount {
let sa
try {
sa = JSON.parse(credentials) as ServiceAccount
Expand All @@ -27,12 +27,16 @@ function getServiceAccount(credentials: string): ServiceAccount {
if (typeof sa.private_key !== 'string') {
throw new ResponseError(400, `"private_key" should be a string, not ${typeof sa.private_key}`)
}
return { client_email: sa.client_email, private_key: sa.private_key }
if (typeof sa.project_id !== 'string') {
throw new ResponseError(400, `"project_id" should be a string, not ${typeof sa.project_id}`)
}
return { client_email: sa.client_email, private_key: sa.private_key, project_id: sa.project_id }
}

// The subset of a service-account credentials JSON that this module keeps
// after parsing; getServiceAccount returns an object of exactly this shape.
interface ServiceAccount {
// Account e-mail taken from the parsed credentials.
client_email: string
// Private key; getServiceAccount rejects credentials where this is not a string.
private_key: string
// Project identifier; getServiceAccount rejects credentials where this is not a string.
project_id: string
}

const encoder = new TextEncoder()
Expand Down
30 changes: 23 additions & 7 deletions gateway/src/providers/google/index.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
import { DefaultProviderProxy, JsonData, isMapping } from '../default'
import { authToken } from './auth'
import { authToken, getServiceAccount } from './auth'
import { otelEvents, GoogleRequest, GenerateContentResponse } from './otel'

export class GoogleVertexProvider extends DefaultProviderProxy {
protected usageField = 'usageMetadata'

url() {
if (this.providerProxy.baseUrl) {
const extra = this.restOfPath
// I think this regex is for GLA aka the google developer API
.replace(/^v1beta\/models\//, '')
// this is for requests expecting google vertex
.replace(/^v1beta1\/publishers\/google\/models\//, '')
return `${this.providerProxy.baseUrl}/${extra}`
// Extract project ID from credentials
const projectId = getServiceAccount(this.providerProxy.credentials).project_id

// Extract location from baseUrl (e.g., us-central1 from https://us-central1-aiplatform.googleapis.com)
// If no location is found, use "global"
const locationMatch = /https:\/\/(.+)-aiplatform\.googleapis\.com/.exec(this.providerProxy.baseUrl)
const location = locationMatch ? locationMatch[1] : 'global'

// Transform the path to inject correct project and location
const extra = transformPath(this.restOfPath, projectId, location!)
const finalUrl = `${this.providerProxy.baseUrl}/${extra}`
console.log('Final URL:', finalUrl)
return finalUrl
} else {
return { error: 'baseUrl is required for the Google Provider' }
}
Expand Down Expand Up @@ -47,3 +54,12 @@ export class GoogleVertexProvider extends DefaultProviderProxy {
return isMapping(responseBody) && typeof responseBody.responseId === 'string' ? responseBody.responseId : undefined
}
}

/**
 * Rewrite a request path so that it addresses the given project and location.
 *
 * Two prefixes are recognised, each only at the very start of the path:
 *  - `v1beta1/publishers/google/models` gains a
 *    `projects/<projectId>/locations/<location>` segment after `v1beta1/`;
 *  - `projects/unset/locations/unset` has its placeholders replaced with the
 *    real project and location.
 * A path that starts with neither prefix is returned unchanged.
 *
 * @param restOfPath path portion of the incoming request
 * @param projectId project ID to inject
 * @param location location to inject
 * @returns the rewritten path
 */
function transformPath(restOfPath: string, projectId: string, location: string): string {
  const scope = `projects/${projectId}/locations/${location}`

  // First pass: paths that name the publisher directly, with no project/location.
  const withPublisherScope = restOfPath.replace(
    /^v1beta1\/publishers\/google\/models/,
    `v1beta1/${scope}/publishers/google/models`,
  )

  // Second pass: paths that begin with the 'unset' placeholders.
  return withPublisherScope.replace(/^projects\/unset\/locations\/unset/, scope)
}
8 changes: 3 additions & 5 deletions proxy-vcr/proxy_vcr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,12 @@ async def proxy(request: Request) -> JSONResponse:
with vcr.use_cassette(f'{body_hash}.yaml'): # type: ignore[reportUnknownReturnType]
headers = {'Authorization': auth_header, 'content-type': 'application/json'}
response = await client.post(url, content=body, headers=headers)
return JSONResponse(response.json(), status_code=response.status_code)
elif request.url.path.startswith('/groq'):
client = cast(httpx.AsyncClient, request.scope['state']['httpx_client'])
url = GROQ_BASE_URL + request.url.path[len('/groq') :]
with vcr.use_cassette(f'{body_hash}.yaml'): # type: ignore[reportUnknownReturnType]
headers = {'Authorization': auth_header, 'content-type': 'application/json'}
response = await client.post(url, content=body, headers=headers)
return JSONResponse(response.json(), status_code=response.status_code)
elif request.url.path.startswith('/anthropic'):
client = cast(httpx.AsyncClient, request.scope['state']['httpx_client'])
url = ANTHROPIC_BASE_URL + request.url.path[len('/anthropic') :]
Expand All @@ -70,9 +68,9 @@ async def proxy(request: Request) -> JSONResponse:
'anthropic-version': request.headers.get('anthropic-version', '2023-06-01'),
}
response = await client.post(url, content=body, headers=headers)
return JSONResponse(response.json(), status_code=response.status_code)
raise HTTPException(status_code=400, detail='Invalid user agent')
# raise HTTPException(status_code=404, detail=f'Path {request.url.path} not supported')
else:
raise HTTPException(status_code=404, detail=f'Path {request.url.path} not supported')
return JSONResponse(response.json(), status_code=response.status_code)


async def health_check(_: Request) -> Response:
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ dev = ["pyright>=1.1.341", "ruff>=0.12.8"]
[tool.uv.sources]
proxy-vcr = { workspace = true }
examples = { workspace = true }
pydantic-ai = { git = "https://github.com/pydantic/pydantic-ai.git", branch = "support-anthropic-gateway" }

[tool.uv.workspace]
members = ["proxy-vcr", "examples"]
Expand Down
40 changes: 28 additions & 12 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading