diff --git a/integration/test_collection_openai.py b/integration/test_collection_openai.py index 32d3de5f6..198eafac8 100644 --- a/integration/test_collection_openai.py +++ b/integration/test_collection_openai.py @@ -740,3 +740,125 @@ def test_near_text_generate_with_dynamic_rag( assert g0.debug is None assert g0.metadata is None assert g1.metadata is None + + +@pytest.mark.parametrize("parameter,answer", [("text", "yes"), ("content", "no")]) +def test_contextualai_generative_search_single( + collection_factory: CollectionFactory, parameter: str, answer: str +) -> None: + """Test Contextual AI generative search with single prompt.""" + api_key = os.environ.get("CONTEXTUAL_API_KEY") + if api_key is None: + pytest.skip("No Contextual AI API key found.") + + collection = collection_factory( + name="TestContextualAIGenerativeSingle", + generative_config=Configure.Generative.contextualai( + model="v2", + max_tokens=100, + temperature=0.1, + system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context. Answer with yes or no only.", + avoid_commentary=False, + ), + vectorizer_config=Configure.Vectorizer.none(), + properties=[ + Property(name="text", data_type=DataType.TEXT), + Property(name="content", data_type=DataType.TEXT), + ], + headers={"X-Contextual-Api-Key": api_key}, + ports=(8086, 50057), + ) + if collection._connection._weaviate_version.is_lower_than(1, 23, 1): + pytest.skip("Generative search requires Weaviate 1.23.1 or higher") + + collection.data.insert_many( + [ + DataObject(properties={"text": "bananas are great", "content": "bananas are bad"}), + DataObject(properties={"text": "apples are great", "content": "apples are bad"}), + ] + ) + + res = collection.generate.fetch_objects( + single_prompt=f"is it good or bad based on {{{parameter}}}? 
Just answer with yes or no without punctuation", + ) + for obj in res.objects: + assert obj.generated is not None + assert obj.generated.lower() == answer + assert res.generated is None + + +def test_contextualai_generative_and_rerank_combined(collection_factory: CollectionFactory) -> None: + """Test Contextual AI generative search combined with reranking.""" + contextual_api_key = os.environ.get("CONTEXTUAL_API_KEY") + if contextual_api_key is None: + pytest.skip("No Contextual AI API key found.") + + collection = collection_factory( + name="TestContextualAIGenerativeAndRerank", + generative_config=Configure.Generative.contextualai( + model="v2", + max_tokens=100, + temperature=0.1, + system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context.", + avoid_commentary=False, + ), + reranker_config=Configure.Reranker.contextualai( + model="ctxl-rerank-v2-instruct-multilingual", + instruction="Prioritize documents that contain the query term", + ), + vectorizer_config=Configure.Vectorizer.text2vec_openai(), + properties=[Property(name="text", data_type=DataType.TEXT)], + headers={"X-Contextual-Api-Key": contextual_api_key}, + ports=(8086, 50057), + ) + if collection._connection._weaviate_version.is_lower_than(1, 23, 1): + pytest.skip("Generative reranking requires Weaviate 1.23.1 or higher") + + insert = collection.data.insert_many( + [{"text": "This is a test"}, {"text": "This is another test"}] + ) + uuid1 = insert.uuids[0] + vector1 = collection.query.fetch_object_by_id(uuid1, include_vector=True).vector + assert vector1 is not None + + for _idx, query in enumerate( + [ + lambda: collection.generate.bm25( + "test", + rerank=Rerank(prop="text", query="another"), + single_prompt="What is it? {text}", + ), + lambda: collection.generate.hybrid( + "test", + rerank=Rerank(prop="text", query="another"), + single_prompt="What is it? 
{text}", + ), + lambda: collection.generate.near_object( + uuid1, + rerank=Rerank(prop="text", query="another"), + single_prompt="What is it? {text}", + ), + lambda: collection.generate.near_vector( + vector1["default"], + rerank=Rerank(prop="text", query="another"), + single_prompt="What is it? {text}", + ), + lambda: collection.generate.near_text( + "test", + rerank=Rerank(prop="text", query="another"), + single_prompt="What is it? {text}", + ), + ] + ): + objects = query().objects + assert len(objects) == 2 + assert objects[0].metadata.rerank_score is not None + assert objects[0].generated is not None + assert objects[1].metadata.rerank_score is not None + assert objects[1].generated is not None + + assert [obj for obj in objects if "another" in obj.properties["text"]][ # type: ignore + 0 + ].metadata.rerank_score > [ + obj for obj in objects if "another" not in obj.properties["text"] + ][0].metadata.rerank_score diff --git a/integration/test_collection_rerank.py b/integration/test_collection_rerank.py index 8c799f52d..61f6bade0 100644 --- a/integration/test_collection_rerank.py +++ b/integration/test_collection_rerank.py @@ -138,3 +138,61 @@ def test_queries_with_rerank_and_group_by(collection_factory: CollectionFactory) ].rerank_score > [group for prop, group in ret.groups.items() if "another" not in prop][ 0 ].rerank_score + + +def test_queries_with_rerank_contextualai(collection_factory: CollectionFactory) -> None: + """Test Contextual AI reranker with various query types.""" + api_key = os.environ.get("CONTEXTUAL_API_KEY") + if api_key is None: + pytest.skip("No Contextual AI API key found.") + + collection = collection_factory( + name="Test_test_queries_with_rerank_contextualai", + reranker_config=wvc.config.Configure.Reranker.contextualai( + model="ctxl-rerank-v2-instruct-multilingual", + instruction="Prioritize documents that contain the query term", + ), + vectorizer_config=wvc.config.Configure.Vectorizer.text2vec_openai(), + 
properties=[wvc.config.Property(name="text", data_type=wvc.config.DataType.TEXT)], + headers={"X-Contextual-Api-Key": api_key}, + ports=(8086, 50057), + ) + if collection._connection._weaviate_version.is_lower_than(1, 23, 1): + pytest.skip("Reranking requires Weaviate 1.23.1 or higher") + + insert = collection.data.insert_many( + [{"text": "This is a test"}, {"text": "This is another test"}] + ) + uuid1 = insert.uuids[0] + vector1 = collection.query.fetch_object_by_id(uuid1, include_vector=True).vector + assert vector1 is not None + + for _idx, query in enumerate( + [ + lambda: collection.query.bm25( + "test", rerank=wvc.query.Rerank(prop="text", query="another") + ), + lambda: collection.query.hybrid( + "test", rerank=wvc.query.Rerank(prop="text", query="another") + ), + lambda: collection.query.near_object( + uuid1, rerank=wvc.query.Rerank(prop="text", query="another") + ), + lambda: collection.query.near_vector( + vector1["default"], rerank=wvc.query.Rerank(prop="text", query="another") + ), + lambda: collection.query.near_text( + "test", rerank=wvc.query.Rerank(prop="text", query="another") + ), + ] + ): + objects = query().objects + assert len(objects) == 2 + assert objects[0].metadata.rerank_score is not None + assert objects[1].metadata.rerank_score is not None + + assert [obj for obj in objects if "another" in obj.properties["text"]][ # type: ignore + 0 + ].metadata.rerank_score > [ + obj for obj in objects if "another" not in obj.properties["text"] + ][0].metadata.rerank_score diff --git a/test/collection/test_classes_generative.py b/test/collection/test_classes_generative.py index 4be69bbf2..69c96a399 100644 --- a/test/collection/test_classes_generative.py +++ b/test/collection/test_classes_generative.py @@ -414,6 +414,27 @@ def test_generative_parameters_images_parsing( ), ), ), + ( + GenerativeConfig.contextualai( + model="v2", + max_tokens=100, + temperature=0.5, + top_p=0.9, + system_prompt="You are a helpful assistant that provides accurate and 
informative responses based on the given context.", + avoid_commentary=False, + )._to_grpc(_GenerativeConfigRuntimeOptions(return_metadata=True)), + generative_pb2.GenerativeProvider( + return_metadata=True, + contextualai=generative_pb2.GenerativeContextualAI( + model="v2", + max_tokens=100, + temperature=0.5, + top_p=0.9, + system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context.", + avoid_commentary=False, + ), + ), + ), ], ) def test_generative_provider_to_grpc( diff --git a/test/collection/test_config.py b/test/collection/test_config.py index fb836cd8a..71eb5ca71 100644 --- a/test/collection/test_config.py +++ b/test/collection/test_config.py @@ -1043,6 +1043,32 @@ def test_config_with_vectorizer_and_properties( } }, ), + ( + Configure.Generative.contextualai(), + { + "generative-contextualai": {}, + }, + ), + ( + Configure.Generative.contextualai( + model="v2", + max_tokens=512, + temperature=0.7, + top_p=0.9, + system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context.", + avoid_commentary=False, + ), + { + "generative-contextualai": { + "model": "v2", + "maxTokensProperty": 512, + "temperatureProperty": 0.7, + "topPProperty": 0.9, + "systemPromptProperty": "You are a helpful assistant that provides accurate and informative responses based on the given context.", + "avoidCommentaryProperty": False, + } + }, + ), ] @@ -1125,6 +1151,26 @@ def test_config_with_generative( "reranker-transformers": {}, }, ), + ( + Configure.Reranker.contextualai(), + { + "reranker-contextualai": {}, + }, + ), + ( + Configure.Reranker.contextualai( + model="ctxl-rerank-v2-instruct-multilingual", + instruction="Prioritize recent documents", + top_n=5, + ), + { + "reranker-contextualai": { + "model": "ctxl-rerank-v2-instruct-multilingual", + "instruction": "Prioritize recent documents", + "topN": 5, + } + }, + ), ] diff --git 
a/weaviate/collections/classes/config.py b/weaviate/collections/classes/config.py index 8bad9617d..618635ccd 100644 --- a/weaviate/collections/classes/config.py +++ b/weaviate/collections/classes/config.py @@ -187,6 +187,7 @@ class GenerativeSearches(str, BaseEnum): ANTHROPIC: Weaviate module backed by Anthropic generative models. ANYSCALE: Weaviate module backed by Anyscale generative models. COHERE: Weaviate module backed by Cohere generative models. + CONTEXTUALAI: Weaviate module backed by ContextualAI generative models. DATABRICKS: Weaviate module backed by Databricks generative models. FRIENDLIAI: Weaviate module backed by FriendliAI generative models. MISTRAL: Weaviate module backed by Mistral generative models. @@ -200,6 +201,7 @@ class GenerativeSearches(str, BaseEnum): ANTHROPIC = "generative-anthropic" ANYSCALE = "generative-anyscale" COHERE = "generative-cohere" + CONTEXTUALAI = "generative-contextualai" DATABRICKS = "generative-databricks" DUMMY = "generative-dummy" FRIENDLIAI = "generative-friendliai" @@ -220,6 +222,7 @@ class Rerankers(str, BaseEnum): Attributes: NONE: No reranker. COHERE: Weaviate module backed by Cohere reranking models. + CONTEXTUALAI: Weaviate module backed by ContextualAI reranking models. TRANSFORMERS: Weaviate module backed by Transformers reranking models. VOYAGEAI: Weaviate module backed by VoyageAI reranking models. JINAAI: Weaviate module backed by JinaAI reranking models. 
@@ -228,6 +231,7 @@ class Rerankers(str, BaseEnum): NONE = "none" COHERE = "reranker-cohere" + CONTEXTUALAI = "reranker-contextualai" TRANSFORMERS = "reranker-transformers" VOYAGEAI = "reranker-voyageai" JINAAI = "reranker-jinaai" @@ -457,6 +461,18 @@ def _to_dict(self) -> Dict[str, Any]: return ret_dict +class _GenerativeContextualAIConfig(_GenerativeProvider): + generative: Union[GenerativeSearches, _EnumLikeStr] = Field( + default=GenerativeSearches.CONTEXTUALAI, frozen=True, exclude=True + ) + model: Optional[str] + maxTokensProperty: Optional[int] + temperatureProperty: Optional[float] + topPProperty: Optional[float] + systemPromptProperty: Optional[str] + avoidCommentaryProperty: Optional[bool] + + class _GenerativeGoogleConfig(_GenerativeProvider): generative: Union[GenerativeSearches, _EnumLikeStr] = Field( default=GenerativeSearches.PALM, frozen=True, exclude=True @@ -562,6 +578,22 @@ def _to_dict(self) -> Dict[str, Any]: return ret_dict +RerankerContextualAIModel = Literal[ + "ctxl-rerank-v2-instruct-multilingual", + "ctxl-rerank-v2-instruct-multilingual-mini", + "ctxl-rerank-v1-instruct", +] + + +class _RerankerContextualAIConfig(_RerankerProvider): + reranker: Union[Rerankers, _EnumLikeStr] = Field( + default=Rerankers.CONTEXTUALAI, frozen=True, exclude=True + ) + model: Optional[Union[RerankerContextualAIModel, str]] = Field(default=None) + instruction: Optional[str] = Field(default=None) + topN: Optional[int] = Field(default=None) + + class _Generative: """Use this factory class to create the correct object for the `generative_config` argument in the `collections.create()` method. 
@@ -829,6 +861,37 @@ def cohere( temperatureProperty=temperature, ) + @staticmethod + def contextualai( + model: Optional[str] = None, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + system_prompt: Optional[str] = None, + avoid_commentary: Optional[bool] = None, + ) -> _GenerativeProvider: + """Create a `_GenerativeContextualAIConfig` object for use when performing AI generation using the `generative-contextualai` module. + + See the [documentation](https://weaviate.io/developers/weaviate/model-providers/contextualai/generative) + for detailed usage. + + Args: + model: The model to use. Defaults to `None`, which uses the server-defined default + max_tokens: The maximum number of tokens to generate. Defaults to `None`, which uses the server-defined default + temperature: The temperature to use. Defaults to `None`, which uses the server-defined default + top_p: Nucleus sampling parameter (0 < x <= 1). Defaults to `None`, which uses the server-defined default + system_prompt: System instructions the model follows. Defaults to `None`, which uses the server-defined default + avoid_commentary: If `True`, reduce conversational commentary in responses. Defaults to `None`, which uses the server-defined default + """ + return _GenerativeContextualAIConfig( + maxTokensProperty=max_tokens, + model=model, + temperatureProperty=temperature, + topPProperty=top_p, + systemPromptProperty=system_prompt, + avoidCommentaryProperty=avoid_commentary, + ) + @staticmethod @docstring_deprecated( deprecated_in="4.9.0", @@ -1048,6 +1111,28 @@ def nvidia( """ return _RerankerNvidiaConfig(model=model, baseURL=base_url) + @staticmethod + def contextualai( + model: Optional[str] = None, + instruction: Optional[str] = None, + top_n: Optional[int] = None, + ) -> _RerankerProvider: + """Create a `_RerankerContextualAIConfig` object for use when reranking using the `reranker-contextualai` module. 
+ + See the [documentation](https://weaviate.io/developers/weaviate/model-providers/contextualai/reranker) + for detailed usage. + + Args: + model: The model to use. Defaults to `None`, which uses the server-defined default + instruction: Custom instructions for reranking. Defaults to `None`. + top_n: Number of top results to return. Defaults to `None`, which uses the server-defined default. + """ + return _RerankerContextualAIConfig( + model=model, + instruction=instruction, + topN=top_n, + ) + class _CollectionConfigCreateBase(_ConfigCreateModel): description: Optional[str] = Field(default=None) diff --git a/weaviate/collections/classes/generative.py b/weaviate/collections/classes/generative.py index 65ed6369a..6e807706f 100644 --- a/weaviate/collections/classes/generative.py +++ b/weaviate/collections/classes/generative.py @@ -445,6 +445,32 @@ def _to_grpc(self, opts: _GenerativeConfigRuntimeOptions) -> generative_pb2.Gene ) + +class _GenerativeContextualAI(_GenerativeConfigRuntime): + generative: Union[GenerativeSearches, _EnumLikeStr] = Field( + default=GenerativeSearches.CONTEXTUALAI, frozen=True, exclude=True + ) + model: Optional[str] + max_tokens: Optional[int] + temperature: Optional[float] + top_p: Optional[float] + system_prompt: Optional[str] + avoid_commentary: Optional[bool] + + def _to_grpc(self, opts: _GenerativeConfigRuntimeOptions) -> generative_pb2.GenerativeProvider: + self._validate_multi_modal(opts) + return generative_pb2.GenerativeProvider( + return_metadata=opts.return_metadata, + contextualai=generative_pb2.GenerativeContextualAI( + model=self.model, + max_tokens=self.max_tokens, + temperature=self.temperature, + top_p=self.top_p, + system_prompt=self.system_prompt, + avoid_commentary=self.avoid_commentary or False, + ), + ) + + class GenerativeConfig: """Use this factory class to create the correct object for the `generative_provider` argument in the search methods of the `.generate` namespace. 
@@ -580,6 +606,35 @@ def cohere( temperature=temperature, ) + @staticmethod + def contextualai( + *, + model: Optional[str] = None, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + system_prompt: Optional[str] = None, + avoid_commentary: Optional[bool] = None, + ) -> _GenerativeConfigRuntime: + """Create a `_GenerativeContextualAI` object for use with the `generative-contextualai` module. + + Args: + model: The model to use. Defaults to `None`, which uses the server-defined default + max_tokens: The maximum number of tokens to generate. Defaults to `None`, which uses the server-defined default + temperature: The temperature to use. Defaults to `None`, which uses the server-defined default + top_p: The top P to use. Defaults to `None`, which uses the server-defined default + system_prompt: The system prompt to prepend to the conversation + avoid_commentary: Whether to avoid model commentary in responses + """ + return _GenerativeContextualAI( + model=model, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + system_prompt=system_prompt, + avoid_commentary=avoid_commentary, + ) + @staticmethod def databricks( *,