Make chat and embed interfaces provider-agnostic using pydantic_ai #200
base: main
Changes from all commits: 74097cf, d59d7b6, ff733f5, 1015737, 4bd1387, 60aa403, 067f3b9, 17b959f, 76621b4, 6f1286f, 83d6f0a
@@ -0,0 +1,292 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Provider-agnostic model configuration backed by pydantic_ai.

Create chat and embedding models from ``provider:model`` spec strings::

    from typeagent.aitools.model_adapters import configure_models

    chat, embedder = configure_models(
        "openai:gpt-4o",
        "openai:text-embedding-3-small",
    )

The spec format is ``provider:model``, matching pydantic_ai conventions.
Provider wiring (API keys, endpoints, etc.) is handled by pydantic_ai's
model registry, which supports 25+ providers including ``openai``,
``azure``, ``anthropic``, ``google``, ``bedrock``, ``groq``, ``mistral``,
``ollama``, ``cohere``, and many more.

When a spec uses ``openai:`` as the provider and ``OPENAI_API_KEY`` is not
set, but ``AZURE_OPENAI_API_KEY`` is available, the provider is
automatically switched to Azure OpenAI.

See https://ai.pydantic.dev/models/ for all supported providers and their
required environment variables.
"""

import os

import numpy as np
from numpy.typing import NDArray

from pydantic_ai import Embedder as _PydanticAIEmbedder
from pydantic_ai.messages import (
    ModelMessage,
    ModelRequest,
    SystemPromptPart,
    TextPart,
    UserPromptPart,
)
from pydantic_ai.models import infer_model, Model, ModelRequestParameters
import typechat

from .embeddings import IEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings

# ---------------------------------------------------------------------------
# Chat model adapter
# ---------------------------------------------------------------------------


class PydanticAIChatModel(typechat.TypeChatLanguageModel):
    """Adapter from :class:`pydantic_ai.models.Model` to TypeChat's
    :class:`~typechat.TypeChatLanguageModel`.

    This lets any pydantic_ai chat model (OpenAI, Anthropic, Google, …) be
    used wherever TypeChat expects a ``TypeChatLanguageModel``.
    """

    def __init__(self, model: Model) -> None:
        self._model = model

    async def complete(
        self, prompt: str | list[typechat.PromptSection]
    ) -> typechat.Result[str]:
        parts: list[SystemPromptPart | UserPromptPart] = []
        if isinstance(prompt, str):
            parts.append(UserPromptPart(content=prompt))
        else:
            for section in prompt:
                if section["role"] == "system":
                    parts.append(SystemPromptPart(content=section["content"]))
                else:
                    parts.append(UserPromptPart(content=section["content"]))

        messages: list[ModelMessage] = [ModelRequest(parts=parts)]
        params = ModelRequestParameters()

        response = await self._model.request(messages, None, params)
        text_parts = [p.content for p in response.parts if isinstance(p, TextPart)]
        if text_parts:
            return typechat.Success("".join(text_parts))
        return typechat.Failure("No text content in model response")


# ---------------------------------------------------------------------------
# Embedding model adapter
# ---------------------------------------------------------------------------


class PydanticAIEmbeddingModel(IEmbeddingModel):
Contributor
AsyncEmbeddingModel (in aitools/embeddings.py) does NOT inherit from IEmbeddingModel, so kind of asymmetric?
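To make the asymmetry the comment points at concrete (class bodies elided; only the inheritance relationship is taken from the comment itself):

    # In aitools/embeddings.py, per the comment: no Protocol declared.
    class AsyncEmbeddingModel:
        ...

    # In this file: the Protocol is declared explicitly.
    class PydanticAIEmbeddingModel(IEmbeddingModel):
        ...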
    """Adapter from :class:`pydantic_ai.Embedder` to :class:`IEmbeddingModel`.

    This lets any pydantic_ai embedding provider (OpenAI, Cohere, Google, …)
    be used wherever the codebase expects an ``IEmbeddingModel``, including
    :class:`~typeagent.aitools.vectorbase.VectorBase` and
    :class:`~typeagent.knowpro.convsettings.ConversationSettings`.

    If *embedding_size* is not given, it is probed automatically by making a
    single embedding call.
    """

    model_name: str
    embedding_size: int

    def __init__(
        self,
        embedder: _PydanticAIEmbedder,
        model_name: str,
        embedding_size: int = 0,
    ) -> None:
        self._embedder = embedder
        self.model_name = model_name
        self.embedding_size = embedding_size
        self._cache: dict[str, NormalizedEmbedding] = {}

    def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None:
        self._cache[key] = embedding

    async def _probe_embedding_size(self) -> None:
Contributor
Why probe the embedding size?
        """Discover embedding_size by making a single API call."""
        result = await self._embedder.embed(["probe"], input_type="document")
        self.embedding_size = len(result.embeddings[0])

    async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding:
        result = await self._embedder.embed([input], input_type="document")
        embedding: NDArray[np.float32] = np.array(
            result.embeddings[0], dtype=np.float32
        )
        if self.embedding_size == 0:
            self.embedding_size = len(embedding)
        norm = float(np.linalg.norm(embedding))
        if norm > 0:
            embedding = (embedding / norm).astype(np.float32)
        return embedding

    async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings:
        if not input:
            if self.embedding_size == 0:
                await self._probe_embedding_size()
Comment on lines +139 to +140
Contributor
Same here; we should make the embedding size a mandatory parameter.
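A minimal sketch of that suggestion (illustrative only, not code from this PR; with a required size, _probe_embedding_size and the == 0 checks could be dropped):

    def __init__(
        self,
        embedder: _PydanticAIEmbedder,
        model_name: str,
        embedding_size: int,  # mandatory: caller must know the dimension
    ) -> None:
        if embedding_size <= 0:
            raise ValueError("embedding_size must be a positive integer")
        self._embedder = embedder
        self.model_name = model_name
        self.embedding_size = embedding_size
        self._cache: dict[str, NormalizedEmbedding] = {}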
            return np.empty((0, self.embedding_size), dtype=np.float32)
        result = await self._embedder.embed(input, input_type="document")
        embeddings: NDArray[np.float32] = np.array(result.embeddings, dtype=np.float32)
        if self.embedding_size == 0:
            self.embedding_size = embeddings.shape[1]
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True).astype(np.float32)
        norms = np.where(norms > 0, norms, np.float32(1.0))
        embeddings = (embeddings / norms).astype(np.float32)
        return embeddings

    async def get_embedding(self, key: str) -> NormalizedEmbedding:
        cached = self._cache.get(key)
        if cached is not None:
            return cached
        embedding = await self.get_embedding_nocache(key)
        self._cache[key] = embedding
        return embedding

    async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings:
        missing_keys = [k for k in keys if k not in self._cache]
        if missing_keys:
            fresh = await self.get_embeddings_nocache(missing_keys)
            for i, k in enumerate(missing_keys):
                self._cache[k] = fresh[i]
        return np.array([self._cache[k] for k in keys], dtype=np.float32)


# ---------------------------------------------------------------------------
# Provider auto-detection
# ---------------------------------------------------------------------------


def _needs_azure_fallback(provider: str) -> bool:
    """Return True if *provider* is ``openai`` but only Azure credentials exist."""
    return (
        provider == "openai"
        and not os.getenv("OPENAI_API_KEY")
        and bool(os.getenv("AZURE_OPENAI_API_KEY"))
    )


def _make_azure_provider():
    """Create a :class:`pydantic_ai.providers.azure.AzureProvider`."""
    from pydantic_ai.providers.azure import AzureProvider

    from .utils import get_azure_api_key, parse_azure_endpoint

    raw_key = os.environ["AZURE_OPENAI_API_KEY"]
    api_key = get_azure_api_key(raw_key)
    azure_endpoint, api_version = parse_azure_endpoint("AZURE_OPENAI_ENDPOINT")
    return AzureProvider(
        azure_endpoint=azure_endpoint,
        api_version=api_version,
        api_key=api_key,
    )


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def create_chat_model(
    model_spec: str,
) -> PydanticAIChatModel:
    """Create a chat model from a ``provider:model`` spec.

    Delegates to :func:`pydantic_ai.models.infer_model` for provider wiring.
    If the spec uses ``openai:`` and ``OPENAI_API_KEY`` is not set but
    ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically.

    Examples::

        model = create_chat_model("openai:gpt-4o")
        model = create_chat_model("anthropic:claude-sonnet-4-20250514")
        model = create_chat_model("google:gemini-2.0-flash")
    """
    provider, _, model_name = model_spec.partition(":")
    if _needs_azure_fallback(provider):
        from pydantic_ai.models.openai import OpenAIChatModel

        model = OpenAIChatModel(model_name, provider=_make_azure_provider())
    else:
        model = infer_model(model_spec)
    return PydanticAIChatModel(model)


DEFAULT_EMBEDDING_SPEC = "openai:text-embedding-3-small"


def create_embedding_model(
    model_spec: str | None = None,
    *,
    embedding_size: int = 0,
) -> PydanticAIEmbeddingModel:
    """Create an embedding model from a ``provider:model`` spec.

    Delegates to :class:`pydantic_ai.Embedder` for provider wiring.
    If the spec uses ``openai:`` and ``OPENAI_API_KEY`` is not set but
    ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically.

    If *model_spec* is ``None``, :data:`DEFAULT_EMBEDDING_SPEC` is used.
    If *embedding_size* is not given, it will be probed automatically
    on the first embedding call.

    Examples::

        model = create_embedding_model("openai:text-embedding-3-small")
        model = create_embedding_model("cohere:embed-english-v3.0")
        model = create_embedding_model("google:text-embedding-004")
    """
    if model_spec is None:
        model_spec = DEFAULT_EMBEDDING_SPEC
    provider, _, model_name = model_spec.partition(":")
    if not model_name:
        model_name = provider  # No colon in spec
    if _needs_azure_fallback(provider):
        from pydantic_ai.embeddings.openai import OpenAIEmbeddingModel

        embedding_model = OpenAIEmbeddingModel(
            model_name, provider=_make_azure_provider()
        )
        embedder = _PydanticAIEmbedder(embedding_model)
    else:
        embedder = _PydanticAIEmbedder(model_spec)
    return PydanticAIEmbeddingModel(embedder, model_name, embedding_size)


def configure_models(
    chat_model_spec: str,
    embedding_model_spec: str,
    *,
    embedding_size: int = 0,
) -> tuple[PydanticAIChatModel, PydanticAIEmbeddingModel]:
    """Configure both a chat model and an embedding model at once.

    Delegates to pydantic_ai's model registry for provider wiring.

    Example::

        chat, embedder = configure_models(
            "openai:gpt-4o",
            "openai:text-embedding-3-small",
        )

        settings = ConversationSettings(model=embedder)
        extractor = KnowledgeExtractor(model=chat)
    """
    return (
        create_chat_model(chat_model_spec),
        create_embedding_model(embedding_model_spec, embedding_size=embedding_size),
    )
The Protocol bundles caching semantics (add_embedding, get_embedding with cache, the _nocache variants) into the provider interface itself. This means every new provider implementation (e.g., Anthropic, Cohere) must re-implement the same caching boilerplate; PydanticAIEmbeddingModel, for example, duplicates nearly identical caching logic from AsyncEmbeddingModel.
Consider splitting the interface, as sketched below:
A minimal IEmbedder Protocol with only get_embedding_nocache / get_embeddings_nocache plus model_name and embedding_size.
A shared CachingEmbeddingModel base class (or a decorator/wrapper) that adds the cache layer on top.
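A rough sketch of that split (IEmbedder and CachingEmbeddingModel are illustrative names, not existing code; assumes the NormalizedEmbedding/NormalizedEmbeddings types from aitools/embeddings.py):

    from typing import Protocol

    import numpy as np

    from typeagent.aitools.embeddings import NormalizedEmbedding, NormalizedEmbeddings


    class IEmbedder(Protocol):
        """Minimal provider surface: no caching concerns."""

        model_name: str
        embedding_size: int

        async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: ...

        async def get_embeddings_nocache(
            self, input: list[str]
        ) -> NormalizedEmbeddings: ...


    class CachingEmbeddingModel:
        """Shared cache layer wrapping any IEmbedder."""

        def __init__(self, embedder: IEmbedder) -> None:
            self._embedder = embedder
            self._cache: dict[str, NormalizedEmbedding] = {}

        def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None:
            self._cache[key] = embedding

        async def get_embedding(self, key: str) -> NormalizedEmbedding:
            cached = self._cache.get(key)
            if cached is None:
                cached = await self._embedder.get_embedding_nocache(key)
                self._cache[key] = cached
            return cached

        async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings:
            missing = [k for k in keys if k not in self._cache]
            if missing:
                fresh = await self._embedder.get_embeddings_nocache(missing)
                for i, k in enumerate(missing):
                    self._cache[k] = fresh[i]
            return np.array([self._cache[k] for k in keys], dtype=np.float32)

PydanticAIEmbeddingModel and AsyncEmbeddingModel would then only need to supply the two _nocache methods, and the cache logic would live in one place.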