diff --git a/libs/partners/mistralai/langchain_mistralai/embeddings.py b/libs/partners/mistralai/langchain_mistralai/embeddings.py index 5fd0d6bb62cea..f8fd08dc72d74 100644 --- a/libs/partners/mistralai/langchain_mistralai/embeddings.py +++ b/libs/partners/mistralai/langchain_mistralai/embeddings.py @@ -1,7 +1,7 @@ import asyncio import logging import warnings -from collections.abc import Iterable +from collections.abc import Callable, Iterable import httpx from httpx import Response @@ -57,6 +57,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings): api_key: The API key for the MistralAI API. If not provided, it will be read from the environment variable `MISTRAL_API_KEY`. + max_concurrent_requests: int max_retries: The number of times to retry a request if it fails. timeout: @@ -133,9 +134,9 @@ class MistralAIEmbeddings(BaseModel, Embeddings): default_factory=secret_from_env("MISTRAL_API_KEY", default=""), ) endpoint: str = "https://api.mistral.ai/v1/" - max_retries: int = 5 + max_retries: int | None = 5 timeout: int = 120 - wait_time: int = 30 + wait_time: int | None = 30 max_concurrent_requests: int = 64 tokenizer: Tokenizer = Field(default=None) @@ -212,6 +213,18 @@ def _get_batches(self, texts: list[str]) -> Iterable[list[str]]: if batch: yield batch + def _retry(self, func: Callable) -> Callable: + if self.max_retries is None or self.wait_time is None: + return func + + return retry( + retry=retry_if_exception_type( + (httpx.TimeoutException, httpx.HTTPStatusError) + ), + wait=wait_fixed(self.wait_time), + stop=stop_after_attempt(self.max_retries), + )(func) + def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed a list of document texts. @@ -225,13 +238,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: try: batch_responses = [] - @retry( - retry=retry_if_exception_type( - (httpx.TimeoutException, httpx.HTTPStatusError) - ), - wait=wait_fixed(self.wait_time), - stop=stop_after_attempt(self.max_retries), - ) + @self._retry def _embed_batch(batch: list[str]) -> Response: response = self.client.post( url="/embeddings", @@ -266,13 +273,7 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """ try: - @retry( - retry=retry_if_exception_type( - (httpx.TimeoutException, httpx.HTTPStatusError) - ), - wait=wait_fixed(self.wait_time), - stop=stop_after_attempt(self.max_retries), - ) + @self._retry async def _aembed_batch(batch: list[str]) -> Response: response = await self.async_client.post( url="/embeddings", diff --git a/libs/partners/mistralai/tests/integration_tests/test_embeddings.py b/libs/partners/mistralai/tests/integration_tests/test_embeddings.py index 3ef91728e2fd7..b5e0f2787cc9f 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_embeddings.py +++ b/libs/partners/mistralai/tests/integration_tests/test_embeddings.py @@ -35,7 +35,7 @@ async def test_mistralai_embedding_documents_async() -> None: assert len(output[0]) == 1024 -async def test_mistralai_embedding_documents_http_error_async() -> None: +async def test_mistralai_embedding_documents_tenacity_error_async() -> None: """Test MistralAI embeddings for documents.""" documents = ["foo bar", "test document"] embedding = MistralAIEmbeddings(max_retries=0) @@ -50,6 +50,21 @@ async def test_mistralai_embedding_documents_http_error_async() -> None: await embedding.aembed_documents(documents) +async def test_mistralai_embedding_documents_http_error_async() -> None: + """Test MistralAI embeddings for documents.""" + documents = ["foo bar", "test document"] + embedding = MistralAIEmbeddings(max_retries=None) + mock_response = httpx.Response( + status_code=400, + request=httpx.Request("POST", url=embedding.async_client.base_url), + ) + with ( + patch.object(embedding.async_client, "post", return_value=mock_response), + pytest.raises(httpx.HTTPStatusError), + ): + await embedding.aembed_documents(documents) + + async def test_mistralai_embedding_query_async() -> None: """Test MistralAI embeddings for query.""" document = "foo bar"