Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"numpy>=2.2.6",
"openai>=1.81.0",
"pydantic>=2.11.4",
"pydantic-ai-slim[openai]>=1.39.0",
"pyreadline3>=3.5.4 ; sys_platform == 'win32'",
"python-dotenv>=1.1.0",
"tiktoken>=0.12.0",
Expand Down Expand Up @@ -85,11 +86,10 @@ dev = [
"google-auth-httplib2>=0.2.0",
"google-auth-oauthlib>=1.2.2",
"isort>=7.0.0",
"logfire>=4.1.0", # So 'make check' passes
"logfire>=4.1.0", # So 'make check' passes
"msgraph-sdk>=1.54.0",
"opentelemetry-instrumentation-httpx>=0.57b0",
"pydantic-ai-slim[openai]>=1.39.0",
"pyright>=1.1.408", # 407 has a regression
"pyright>=1.1.408", # 407 has a regression
"pytest>=8.3.5",
"pytest-asyncio>=0.26.0",
"pytest-mock>=3.14.0",
Expand Down
43 changes: 41 additions & 2 deletions src/typeagent/aitools/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
import os
from typing import Protocol, runtime_checkable

import numpy as np
from numpy.typing import NDArray
Expand All @@ -20,6 +21,39 @@
type NormalizedEmbeddings = NDArray[np.float32] # An array of embeddings


@runtime_checkable
class IEmbeddingModel(Protocol):
    """Provider-agnostic interface for embedding models.

    Implement this protocol to add support for a new embedding provider
    (e.g. Anthropic, Gemini, local models). The existing AsyncEmbeddingModel
    implements it for OpenAI and Azure OpenAI.
    """

    # Identifier of the provider's model (e.g. "text-embedding-3-small").
    model_name: str
    # Dimensionality of the embeddings produced by this model.
    embedding_size: int

    def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None:
        """Cache an already-computed embedding under the given key."""
        ...

    async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding:
        """Compute a single embedding without caching."""
        ...

    async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings:
        """Compute embeddings for a batch of strings without caching."""
        ...

    async def get_embedding(self, key: str) -> NormalizedEmbedding:
        """Retrieve a single embedding, using cache if available."""
        ...

    async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings:
        """Retrieve embeddings for multiple keys, using cache if available."""
        ...


DEFAULT_MODEL_NAME = "text-embedding-ada-002"
DEFAULT_EMBEDDING_SIZE = 1536 # Default embedding size (required for ada-002)
DEFAULT_ENVVAR = "AZURE_OPENAI_ENDPOINT_EMBEDDING" # We support OpenAI and Azure OpenAI
Expand Down Expand Up @@ -60,6 +94,7 @@ def __init__(
model_name: str | None = None,
endpoint_envvar: str | None = None,
max_retries: int = DEFAULT_MAX_RETRIES,
use_azure: bool | None = None,
):
if model_name is None:
model_name = DEFAULT_MODEL_NAME
Expand Down Expand Up @@ -88,8 +123,12 @@ def __init__(
openai_api_key = os.getenv("OPENAI_API_KEY")
azure_api_key = os.getenv("AZURE_OPENAI_API_KEY")

# Prefer OpenAI if both are set, use Azure if only Azure is set
self.use_azure = bool(azure_api_key) and not bool(openai_api_key)
# Determine provider: explicit use_azure overrides auto-detection.
if use_azure is not None:
self.use_azure = use_azure
else:
# Prefer OpenAI if both are set, use Azure if only Azure is set
self.use_azure = bool(azure_api_key) and not bool(openai_api_key)

if endpoint_envvar is None:
# Check if OpenAI credentials are available, prefer OpenAI over Azure
Expand Down
292 changes: 292 additions & 0 deletions src/typeagent/aitools/model_adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Provider-agnostic model configuration backed by pydantic_ai.

Create chat and embedding models from ``provider:model`` spec strings::

from typeagent.aitools.model_adapters import configure_models

chat, embedder = configure_models(
"openai:gpt-4o",
"openai:text-embedding-3-small",
)

The spec format is ``provider:model``, matching pydantic_ai conventions.
Provider wiring (API keys, endpoints, etc.) is handled by pydantic_ai's
model registry, which supports 25+ providers including ``openai``,
``azure``, ``anthropic``, ``google``, ``bedrock``, ``groq``, ``mistral``,
``ollama``, ``cohere``, and many more.

When a spec uses ``openai:`` as the provider and ``OPENAI_API_KEY`` is not
set, but ``AZURE_OPENAI_API_KEY`` is available, the provider is
automatically switched to Azure OpenAI.

See https://ai.pydantic.dev/models/ for all supported providers and their
required environment variables.
"""

import os

import numpy as np
from numpy.typing import NDArray

from pydantic_ai import Embedder as _PydanticAIEmbedder
from pydantic_ai.messages import (
ModelMessage,
ModelRequest,
SystemPromptPart,
TextPart,
UserPromptPart,
)
from pydantic_ai.models import infer_model, Model, ModelRequestParameters
import typechat

from .embeddings import IEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings

# ---------------------------------------------------------------------------
# Chat model adapter
# ---------------------------------------------------------------------------


class PydanticAIChatModel(typechat.TypeChatLanguageModel):
    """Adapter from :class:`pydantic_ai.models.Model` to TypeChat's
    :class:`~typechat.TypeChatLanguageModel`.

    This lets any pydantic_ai chat model (OpenAI, Anthropic, Google, …) be
    used wherever TypeChat expects a ``TypeChatLanguageModel``.
    """

    def __init__(self, model: Model) -> None:
        self._model = model

    async def complete(
        self, prompt: str | list[typechat.PromptSection]
    ) -> typechat.Result[str]:
        # Normalize the prompt into pydantic_ai prompt parts: a bare string
        # becomes one user part; sections map by role.
        if isinstance(prompt, str):
            parts: list[SystemPromptPart | UserPromptPart] = [
                UserPromptPart(content=prompt)
            ]
        else:
            parts = [
                (
                    SystemPromptPart(content=section["content"])
                    if section["role"] == "system"
                    else UserPromptPart(content=section["content"])
                )
                for section in prompt
            ]

        request: list[ModelMessage] = [ModelRequest(parts=parts)]
        response = await self._model.request(
            request, None, ModelRequestParameters()
        )

        # Collect only the textual parts of the response; other part kinds
        # (tool calls, etc.) are ignored.
        text_chunks = [
            part.content for part in response.parts if isinstance(part, TextPart)
        ]
        if not text_chunks:
            return typechat.Failure("No text content in model response")
        return typechat.Success("".join(text_chunks))


# ---------------------------------------------------------------------------
# Embedding model adapter
# ---------------------------------------------------------------------------


class PydanticAIEmbeddingModel(IEmbeddingModel):
    """Adapter from :class:`pydantic_ai.Embedder` to :class:`IEmbeddingModel`.

    This lets any pydantic_ai embedding provider (OpenAI, Cohere, Google, …)
    be used wherever the codebase expects an ``IEmbeddingModel``, including
    :class:`~typeagent.aitools.vectorbase.VectorBase` and
    :class:`~typeagent.knowpro.convsettings.ConversationSettings`.

    If *embedding_size* is not given, it is probed automatically by making a
    single embedding call.
    """

    model_name: str
    embedding_size: int

    def __init__(
        self,
        embedder: _PydanticAIEmbedder,
        model_name: str,
        embedding_size: int = 0,
    ) -> None:
        self._embedder = embedder
        self.model_name = model_name
        # 0 means "unknown"; discovered lazily on the first embedding call.
        self.embedding_size = embedding_size
        self._cache: dict[str, NormalizedEmbedding] = {}

    def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None:
        """Cache an already-computed embedding under the given key."""
        self._cache[key] = embedding

    async def _probe_embedding_size(self) -> None:
        """Discover embedding_size by making a single API call."""
        result = await self._embedder.embed(["probe"], input_type="document")
        self.embedding_size = len(result.embeddings[0])

    async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding:
        """Compute a single embedding (no cache), L2-normalized."""
        result = await self._embedder.embed([input], input_type="document")
        embedding: NDArray[np.float32] = np.array(
            result.embeddings[0], dtype=np.float32
        )
        if self.embedding_size == 0:
            self.embedding_size = len(embedding)
        norm = float(np.linalg.norm(embedding))
        # Leave an all-zero vector untouched to avoid division by zero.
        if norm > 0:
            embedding = (embedding / norm).astype(np.float32)
        return embedding

    async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings:
        """Compute a batch of embeddings (no cache), L2-normalized row-wise."""
        if not input:
            # An empty result still needs a concrete width; probe if unknown.
            if self.embedding_size == 0:
                await self._probe_embedding_size()
            return np.empty((0, self.embedding_size), dtype=np.float32)
        result = await self._embedder.embed(input, input_type="document")
        embeddings: NDArray[np.float32] = np.array(result.embeddings, dtype=np.float32)
        if self.embedding_size == 0:
            self.embedding_size = embeddings.shape[1]
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True).astype(np.float32)
        # Replace zero norms with 1.0 so zero rows pass through unchanged.
        norms = np.where(norms > 0, norms, np.float32(1.0))
        embeddings = (embeddings / norms).astype(np.float32)
        return embeddings

    async def get_embedding(self, key: str) -> NormalizedEmbedding:
        """Retrieve a single embedding, consulting the cache first."""
        cached = self._cache.get(key)
        if cached is not None:
            return cached
        embedding = await self.get_embedding_nocache(key)
        self._cache[key] = embedding
        return embedding

    async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings:
        """Retrieve embeddings for *keys*, fetching only the cache misses."""
        missing_keys = [k for k in keys if k not in self._cache]
        if missing_keys:
            fresh = await self.get_embeddings_nocache(missing_keys)
            for i, k in enumerate(missing_keys):
                self._cache[k] = fresh[i]
        return np.array([self._cache[k] for k in keys], dtype=np.float32)


# ---------------------------------------------------------------------------
# Provider auto-detection
# ---------------------------------------------------------------------------


def _needs_azure_fallback(provider: str) -> bool:
"""Return True if *provider* is ``openai`` but only Azure credentials exist."""
return (
provider == "openai"
and not os.getenv("OPENAI_API_KEY")
and bool(os.getenv("AZURE_OPENAI_API_KEY"))
)


def _make_azure_provider():
    """Create a :class:`pydantic_ai.providers.azure.AzureProvider`.

    Credential and endpoint resolution is delegated to helpers in
    ``.utils``; raises ``KeyError`` if ``AZURE_OPENAI_API_KEY`` is unset.
    """
    # Imported lazily so the Azure machinery is only loaded when needed.
    from pydantic_ai.providers.azure import AzureProvider

    from .utils import get_azure_api_key, parse_azure_endpoint

    key = get_azure_api_key(os.environ["AZURE_OPENAI_API_KEY"])
    endpoint, version = parse_azure_endpoint("AZURE_OPENAI_ENDPOINT")
    return AzureProvider(
        azure_endpoint=endpoint,
        api_version=version,
        api_key=key,
    )


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def create_chat_model(
    model_spec: str,
) -> PydanticAIChatModel:
    """Create a chat model from a ``provider:model`` spec.

    Delegates to :func:`pydantic_ai.models.infer_model` for provider wiring.
    If the spec uses ``openai:`` and ``OPENAI_API_KEY`` is not set but
    ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically.

    Examples::

        model = create_chat_model("openai:gpt-4o")
        model = create_chat_model("anthropic:claude-sonnet-4-20250514")
        model = create_chat_model("google:gemini-2.0-flash")
    """
    provider, _, model_name = model_spec.partition(":")
    if not _needs_azure_fallback(provider):
        return PydanticAIChatModel(infer_model(model_spec))

    # Only Azure credentials exist: route the OpenAI model through Azure.
    from pydantic_ai.models.openai import OpenAIChatModel

    azure_model = OpenAIChatModel(model_name, provider=_make_azure_provider())
    return PydanticAIChatModel(azure_model)


DEFAULT_EMBEDDING_SPEC = "openai:text-embedding-3-small"


def create_embedding_model(
    model_spec: str | None = None,
    *,
    embedding_size: int = 0,
) -> PydanticAIEmbeddingModel:
    """Create an embedding model from a ``provider:model`` spec.

    Delegates to :class:`pydantic_ai.Embedder` for provider wiring.
    If the spec uses ``openai:`` and ``OPENAI_API_KEY`` is not set but
    ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically.

    If *model_spec* is ``None``, :data:`DEFAULT_EMBEDDING_SPEC` is used.
    If *embedding_size* is not given, it will be probed automatically
    on the first embedding call.

    Examples::

        model = create_embedding_model("openai:text-embedding-3-small")
        model = create_embedding_model("cohere:embed-english-v3.0")
        model = create_embedding_model("google:text-embedding-004")
    """
    spec = DEFAULT_EMBEDDING_SPEC if model_spec is None else model_spec
    provider, _, model_name = spec.partition(":")
    if not model_name:
        # Spec contained no colon: the whole string is the model name.
        model_name = provider
    if _needs_azure_fallback(provider):
        # Only Azure credentials exist: build the OpenAI embedding model
        # against the Azure provider instead of the default wiring.
        from pydantic_ai.embeddings.openai import OpenAIEmbeddingModel

        azure_model = OpenAIEmbeddingModel(
            model_name, provider=_make_azure_provider()
        )
        embedder = _PydanticAIEmbedder(azure_model)
    else:
        embedder = _PydanticAIEmbedder(spec)
    return PydanticAIEmbeddingModel(embedder, model_name, embedding_size)


def configure_models(
    chat_model_spec: str,
    embedding_model_spec: str,
    *,
    embedding_size: int = 0,
) -> tuple[PydanticAIChatModel, PydanticAIEmbeddingModel]:
    """Configure both a chat model and an embedding model at once.

    Delegates to pydantic_ai's model registry for provider wiring.

    Example::

        chat, embedder = configure_models(
            "openai:gpt-4o",
            "openai:text-embedding-3-small",
        )

        settings = ConversationSettings(model=embedder)
        extractor = KnowledgeExtractor(model=chat)
    """
    chat = create_chat_model(chat_model_spec)
    embedder = create_embedding_model(
        embedding_model_spec, embedding_size=embedding_size
    )
    return chat, embedder
Loading
Loading