Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
WebSearchTool,
WebSearchUserLocation,
)
from .embeddings import (
Embedder,
)
from .exceptions import (
AgentRunError,
ApprovalRequired,
Expand Down Expand Up @@ -119,6 +122,8 @@
'UserPromptNode',
'capture_run_messages',
'InstrumentationSettings',
# embeddings
'Embedder',
# exceptions
'AgentRunError',
'CallDeferred',
Expand Down
137 changes: 137 additions & 0 deletions pydantic_ai_slim/pydantic_ai/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Literal, overload

from typing_extensions import TypeAliasType

from pydantic_ai import _utils
from pydantic_ai.embeddings.embedding_model import EmbeddingModel
from pydantic_ai.embeddings.settings import EmbeddingSettings, merge_embedding_settings
from pydantic_ai.exceptions import UserError
from pydantic_ai.models.instrumented import InstrumentationSettings
from pydantic_ai.providers import infer_provider

# NOTE(review): add a test modelled on `test_known_model_names` to verify this list stays up to date.
KnownEmbeddingModelName = TypeAliasType(
    'KnownEmbeddingModelName',
    Literal[
        'openai:text-embedding-ada-002',
        'openai:text-embedding-3-small',
        'openai:text-embedding-3-large',
        'cohere:embed-v4.0',
    ],
)
"""Known model names that can be used with the `model` parameter of [`Embedder`][pydantic_ai.embeddings.Embedder].

`KnownEmbeddingModelName` is provided as a concise way to specify an embedding model,
written as a `'<provider>:<model name>'` string.
"""


def infer_model(model: EmbeddingModel | KnownEmbeddingModelName | str) -> EmbeddingModel:
    """Resolve `model` to an `EmbeddingModel` instance.

    An existing `EmbeddingModel` is returned unchanged. A string must be of the form
    `'<provider>:<model name>'`: the part before the first `:` is passed to
    `infer_provider` and also selects which embedding model class is instantiated.

    Args:
        model: An embedding model instance, or a provider-prefixed model name.

    Returns:
        The embedding model to use.

    Raises:
        ValueError: If the name contains no `:` separating provider and model name.
        UserError: If no embedding model class is registered for the provider.
    """
    if isinstance(model, EmbeddingModel):
        return model

    try:
        provider_name, model_name = model.split(':', maxsplit=1)
    except ValueError as e:
        raise ValueError('You must provide a provider prefix when specifying an embedding model name') from e

    provider = infer_provider(provider_name)

    # The provider is inferred from the full name; the model class is chosen from the
    # name with any leading 'gateway/' removed.
    model_kind = provider_name.removeprefix('gateway/')

    # TODO: extend the following list for other providers as appropriate
    # NOTE(review): we'll have to check which of the OpenAI-compatible APIs also support embeddings
    openai_compatible_kinds = ('openai',)
    if model_kind in openai_compatible_kinds:
        model_kind = 'openai'

    # Model modules are imported lazily so that only the selected provider's SDK is required.
    if model_kind == 'openai':
        from .openai import OpenAIEmbeddingModel

        return OpenAIEmbeddingModel(model_name, provider=provider)

    if model_kind == 'cohere':
        from .cohere import CohereEmbeddingModel

        return CohereEmbeddingModel(model_name, provider=provider)

    raise UserError(f'Unknown embeddings model: {model}')  # pragma: no cover
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/ggozad/haiku.rag/tree/main/src/haiku/rag/embeddings has Ollama, vLLM and VoyageAI, which would be worth adding as well



@dataclass(init=False)
class Embedder:
    """Computes embeddings for documents with a configurable embedding model.

    The model may be given as an `EmbeddingModel` instance or as a provider-prefixed
    name; a name is resolved with `infer_model` either at construction time or on
    first use, depending on `defer_model_check`.
    """

    instrument: InstrumentationSettings | bool | None
    """Options to automatically instrument with OpenTelemetry."""

    def __init__(
        self,
        model: EmbeddingModel | KnownEmbeddingModelName | str,
        *,
        settings: EmbeddingSettings | None = None,
        defer_model_check: bool = True,
        # TODO: Figure out instrumentation later..
        instrument: InstrumentationSettings | bool | None = None,
    ) -> None:
        """Create an embedder.

        Args:
            model: The embedding model to use, as an instance or a `'<provider>:<model name>'` string.
            settings: Default embedding settings; merged with the per-call settings in `embed`.
            defer_model_check: If `True` (the default), a model name is kept as given and only
                resolved on first use. If `False`, it is resolved with `infer_model` immediately.
            instrument: Instrumentation options. Stored only; not applied to the model yet.
        """
        self._model = model if defer_model_check else infer_model(model)
        self._settings = settings
        # Store under the declared field name: the dataclass-generated `__repr__` and
        # `__eq__` read `self.instrument`, which was previously never assigned.
        self.instrument = instrument

        self._override_model: ContextVar[EmbeddingModel | None] = ContextVar('_override_model', default=None)

    @property
    def model(self) -> EmbeddingModel | KnownEmbeddingModelName | str:
        """The configured model: an instance, or the still-unresolved model name."""
        return self._model

    @contextmanager
    def override(
        self,
        *,
        model: EmbeddingModel | KnownEmbeddingModelName | str | _utils.Unset = _utils.UNSET,
    ) -> Iterator[None]:
        """Temporarily replace the model used by `embed` for the duration of the `with` block.

        Args:
            model: The model to use inside the block. When left unset, nothing is overridden.
        """
        if _utils.is_set(model):
            model_token = self._override_model.set(infer_model(model))
        else:
            model_token = None

        try:
            yield
        finally:
            # Restore the previous value even if the body raised.
            if model_token is not None:
                self._override_model.reset(model_token)

    @overload
    async def embed(self, documents: str, *, settings: EmbeddingSettings | None = None) -> list[float]:
        pass

    @overload
    async def embed(self, documents: Sequence[str], *, settings: EmbeddingSettings | None = None) -> list[list[float]]:
        pass

    async def embed(
        self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> list[float] | list[list[float]]:
        """Embed one document or a sequence of documents.

        Args:
            documents: A single string, or a sequence of strings.
            settings: Per-call settings, merged with the embedder's default settings.

        Returns:
            Whatever the underlying model's `embed` returns; per the overloads, one vector
            for a single string and a list of vectors for a sequence.
        """
        model = self._get_model()
        settings = merge_embedding_settings(self._settings, settings)
        return await model.embed(documents, settings=settings)

    def _get_model(self) -> EmbeddingModel:
        """Return the embedding model to use for the current call.

        A model set through `override` takes precedence. Otherwise the configured model is
        resolved with `infer_model` and the result is stored back on the embedder.

        Returns:
            The embedding model to use
        """
        model_: EmbeddingModel
        override_model = self._override_model.get()
        if override_model is not None:
            model_ = override_model
        else:
            model_ = self._model = infer_model(self.model)

        # TODO: Port the instrumentation logic from Model once we settle on an embeddings API
        # instrument = self.instrument
        # if instrument is None:
        #     instrument = Agent._instrument_default
        #
        # return instrument_model(model_, instrument)

        return model_
104 changes: 104 additions & 0 deletions pydantic_ai_slim/pydantic_ai/embeddings/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Literal, cast, overload

from pydantic_ai.embeddings.embedding_model import EmbeddingModel
from pydantic_ai.embeddings.settings import EmbeddingSettings
from pydantic_ai.providers import Provider, infer_provider

from .settings import merge_embedding_settings

try:
from cohere import AsyncClientV2
except ImportError as _import_error:
raise ImportError(
'Please install `cohere` to use the Cohere embeddings model, '
'you can use the `cohere` optional group — `pip install "pydantic-ai-slim[cohere]"`'
) from _import_error

LatestCohereEmbeddingModelNames = Literal[
'cohere:embed-v4.0',
# TODO: Add the others
]
"""Latest Cohere embeddings models."""

CohereEmbeddingModelName = str | LatestCohereEmbeddingModelNames
"""Possible Cohere embeddings model names."""


@dataclass(init=False)
class CohereEmbeddingModel(EmbeddingModel):
    """Embedding model backed by the Cohere API through a `Provider[AsyncClientV2]`."""

    _model_name: CohereEmbeddingModelName = field(repr=False)
    _provider: Provider[AsyncClientV2] = field(repr=False)

    def __init__(
        self,
        model_name: CohereEmbeddingModelName,
        *,
        provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
        settings: EmbeddingSettings | None = None,
    ):
        """Initialize a Cohere embedding model.

        Args:
            model_name: The name of the Cohere embedding model to use. List of model names
                available [here](https://docs.cohere.com/docs/models#embed).
            provider: The provider to use for authentication and API access. Can be either the string
                'cohere' or an instance of `Provider[AsyncClientV2]`. A string is resolved with
                `infer_provider`.
            settings: Model-specific settings that will be used as defaults for this model.
        """
        self._model_name = model_name

        if isinstance(provider, str):
            provider = infer_provider(provider)
        self._provider = provider
        self._client = provider.client

        super().__init__(settings=settings)

    @property
    def base_url(self) -> str:
        """The base URL for the provider API, if available."""
        return self._provider.base_url

    @property
    def model_name(self) -> CohereEmbeddingModelName:
        """The embedding model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The embedding model provider."""
        return self._provider.name

    @overload
    async def embed(self, documents: str, *, settings: EmbeddingSettings | None = None) -> list[float]:
        pass

    @overload
    async def embed(self, documents: Sequence[str], *, settings: EmbeddingSettings | None = None) -> list[list[float]]:
        pass

    async def embed(
        self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> list[float] | list[list[float]]:
        """Embed one document or a sequence of documents with the Cohere client.

        Args:
            documents: A single string, or a sequence of strings.
            settings: Per-call settings, merged with the model's default settings. The keys
                read here are `input_type` (defaults to `'search_document'`) and `output_dimension`.

        Returns:
            One vector when `documents` is a single string, otherwise one vector per document.
        """
        input_is_string = isinstance(documents, str)
        # The client takes a list of texts, so a lone string is wrapped and unwrapped again below.
        texts: Sequence[str] = [documents] if isinstance(documents, str) else documents

        settings = merge_embedding_settings(self._settings, settings) or {}
        response = await self._client.embed(
            model=self.model_name,
            input_type=settings.get('input_type', 'search_document'),
            # Request float vectors explicitly, since `float_` is the only field read below.
            embedding_types=['float'],
            texts=texts,
            output_dimension=settings.get('output_dimension'),
        )
        embeddings = response.embeddings.float_
        assert embeddings is not None, 'This is a bug in cohere?'

        if input_is_string:
            return embeddings[0]

        return embeddings
55 changes: 55 additions & 0 deletions pydantic_ai_slim/pydantic_ai/embeddings/embedding_model.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to move this to __init__

Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import overload

from pydantic_ai.embeddings.settings import EmbeddingSettings


class EmbeddingModel(ABC):
    """Abstract base class for embedding models.

    Subclasses must implement `model_name` and should override `embed`, whose base
    implementation raises `NotImplementedError`.
    """

    # NOTE(review): consider naming this `EmbeddingsModel`, as the module is called `embeddings`.

    _settings: EmbeddingSettings | None = None

    def __init__(
        self,
        *,
        settings: EmbeddingSettings | None = None,
    ) -> None:
        """Initialize the model with optional settings.

        Args:
            settings: Model-specific settings that will be used as defaults for this model.
        """
        self._settings = settings

    @property
    def settings(self) -> EmbeddingSettings | None:
        """Get the model settings."""
        return self._settings

    @property
    @abstractmethod
    def model_name(self) -> str:
        """The model name."""
        raise NotImplementedError()

    # TODO: Add system?

    @property
    def base_url(self) -> str | None:
        """The base URL for the provider API, if available."""
        return None

    @overload
    async def embed(self, documents: str, *, settings: EmbeddingSettings | None = None) -> list[float]:
        pass

    @overload
    async def embed(self, documents: Sequence[str], *, settings: EmbeddingSettings | None = None) -> list[list[float]]:
        pass

    async def embed(
        self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> list[float] | list[list[float]]:
        """Embed one document or a sequence of documents.

        Per the overloads, a single string yields one vector and a sequence yields a
        list of vectors.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
Loading
Loading