diff --git a/bc2/lib/embedding/openai.py b/bc2/lib/embedding/openai.py index cd49767..d1a4f41 100644 --- a/bc2/lib/embedding/openai.py +++ b/bc2/lib/embedding/openai.py @@ -107,7 +107,13 @@ def __init__( def embed(self, text: str) -> Embedding: text = self._trim_input(text) - result = self.client.embeddings.create(input=text, model=self.config.model) + params = { + "input": text, + "model": self.config.model, + } + if self.config.dimensions: + params["dimensions"] = self.config.dimensions + result = self.client.embeddings.create(**params) return self._format_result(result) async def embed_async(self, text: str) -> Embedding: diff --git a/bc2/lib/embedding/test_openai.py b/bc2/lib/embedding/test_openai.py index a98d9fb..bd555a0 100644 --- a/bc2/lib/embedding/test_openai.py +++ b/bc2/lib/embedding/test_openai.py @@ -1,9 +1,16 @@ +from unittest.mock import MagicMock + import pytest +from openai import AsyncOpenAI, OpenAI from bc2.core.common.openai import OpenAIClientConfig from bc2.core.common.openai_metadata import get_encoding_for_model -from .openai import OpenAIEmbeddingConfig, OpenAIEmbeddingGeneratorConfig +from .openai import ( + OpenAIEmbeddingConfig, + OpenAIEmbeddingDriver, + OpenAIEmbeddingGeneratorConfig, +) @pytest.mark.parametrize( @@ -33,3 +40,57 @@ def test_trim_input(model: str, text: str, trimmed: str): assert embed.driver._trim_input(text) == trimmed assert len(encoding.encode(trimmed)) <= embed.generator.max_input_tokens + + +def _mock_embedding_response(model: str = "text-embedding-3-large") -> MagicMock: + response = MagicMock() + response.data = [MagicMock(embedding=[0.0, 0.1, 0.2])] + response.model = model + return response + + +@pytest.mark.parametrize( + "dimensions", + [None, 256, 1024], + ids=["unset", "256", "1024"], +) +def test_embed_passes_dimensions(dimensions: int | None): + client = MagicMock(spec=OpenAI) + client.embeddings.create.return_value = _mock_embedding_response() + aclient = MagicMock(spec=AsyncOpenAI) + + config = OpenAIEmbeddingGeneratorConfig( + model="text-embedding-3-large", + dimensions=dimensions, + ) + driver = OpenAIEmbeddingDriver(client, aclient, config) + + driver.embed("hello") + + client.embeddings.create.assert_called_once() + call_kwargs = client.embeddings.create.call_args.kwargs + assert call_kwargs["input"] == "hello" + assert call_kwargs["model"] == "text-embedding-3-large" + if dimensions is None: + assert "dimensions" not in call_kwargs + else: + assert call_kwargs["dimensions"] == dimensions + + +def test_embed_config_dimensions_default(): + config = OpenAIEmbeddingConfig( + client=OpenAIClientConfig(api_key="test"), + generator=OpenAIEmbeddingGeneratorConfig(model="text-embedding-3-large"), + ) + assert config.generator.dimensions is None + + +def test_embed_config_dimensions_override(): + config = OpenAIEmbeddingConfig( + client=OpenAIClientConfig(api_key="test"), + generator=OpenAIEmbeddingGeneratorConfig( + model="text-embedding-3-large", dimensions=512 + ), + ) + assert config.generator.dimensions == 512 + assert config.generator.model_dimensions == 512