Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions integration/test_collection_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,3 +740,125 @@ def test_near_text_generate_with_dynamic_rag(
assert g0.debug is None
assert g0.metadata is None
assert g1.metadata is None


@pytest.mark.parametrize("parameter,answer", [("text", "yes"), ("content", "no")])
def test_contextualai_generative_search_single(
    collection_factory: CollectionFactory, parameter: str, answer: str
) -> None:
    """Test Contextual AI generative search with single prompt."""
    contextual_key = os.environ.get("CONTEXTUAL_API_KEY")
    if contextual_key is None:
        pytest.skip("No Contextual AI API key found.")

    collection = collection_factory(
        name="TestContextualAIGenerativeSingle",
        generative_config=Configure.Generative.contextualai(
            model="v2",
            max_tokens=100,
            temperature=0.1,
            system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context. Answer with yes or no only.",
            avoid_commentary=False,
        ),
        vectorizer_config=Configure.Vectorizer.none(),
        properties=[
            Property(name="text", data_type=DataType.TEXT),
            Property(name="content", data_type=DataType.TEXT),
        ],
        headers={"X-Contextual-Api-Key": contextual_key},
        ports=(8086, 50057),
    )
    if collection._connection._weaviate_version.is_lower_than(1, 23, 1):
        pytest.skip("Generative search requires Weaviate 1.23.1 or higher")

    # Two objects whose "text"/"content" properties carry opposite sentiment,
    # so the expected answer depends on which property the prompt references.
    collection.data.insert_many(
        [
            DataObject(properties={"text": "bananas are great", "content": "bananas are bad"}),
            DataObject(properties={"text": "apples are great", "content": "apples are bad"}),
        ]
    )

    # Build "...based on {text}..." / "...based on {content}..." — the braces
    # are the single-prompt property-interpolation syntax, not an f-string.
    prompt = (
        "is it good or bad based on {" + parameter + "}? "
        "Just answer with yes or no without punctuation"
    )
    res = collection.generate.fetch_objects(single_prompt=prompt)
    # A single prompt yields one generation per object; the grouped result stays unset.
    for returned in res.objects:
        assert returned.generated is not None
        assert returned.generated.lower() == answer
    assert res.generated is None


def test_contextualai_generative_and_rerank_combined(collection_factory: CollectionFactory) -> None:
    """Test Contextual AI generative search combined with reranking.

    Runs every query flavor (bm25/hybrid/near_object/near_vector/near_text)
    with both a reranker and a single prompt, and checks that rerank scores
    and generations are populated and that the reranked ordering favors the
    object matching the rerank query.
    """
    contextual_api_key = os.environ.get("CONTEXTUAL_API_KEY")
    if contextual_api_key is None:
        pytest.skip("No Contextual AI API key found.")

    collection = collection_factory(
        name="TestContextualAIGenerativeAndRerank",
        generative_config=Configure.Generative.contextualai(
            model="v2",
            max_tokens=100,
            temperature=0.1,
            system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context.",
            avoid_commentary=False,
        ),
        reranker_config=Configure.Reranker.contextualai(
            model="ctxl-rerank-v2-instruct-multilingual",
            instruction="Prioritize documents that contain the query term",
        ),
        vectorizer_config=Configure.Vectorizer.text2vec_openai(),
        properties=[Property(name="text", data_type=DataType.TEXT)],
        headers={"X-Contextual-Api-Key": contextual_api_key},
        ports=(8086, 50057),
    )
    # Use the is_lower_than helper for consistency with the sibling tests
    # (instead of constructing a _ServerVersion for comparison).
    if collection._connection._weaviate_version.is_lower_than(1, 23, 1):
        pytest.skip("Generative reranking requires Weaviate 1.23.1 or higher")

    insert = collection.data.insert_many(
        [{"text": "This is a test"}, {"text": "This is another test"}]
    )
    uuid1 = insert.uuids[0]
    vector1 = collection.query.fetch_object_by_id(uuid1, include_vector=True).vector
    assert vector1 is not None

    # The enumerate index was unused; iterate the query thunks directly.
    for query in [
        lambda: collection.generate.bm25(
            "test",
            rerank=Rerank(prop="text", query="another"),
            single_prompt="What is it? {text}",
        ),
        lambda: collection.generate.hybrid(
            "test",
            rerank=Rerank(prop="text", query="another"),
            single_prompt="What is it? {text}",
        ),
        lambda: collection.generate.near_object(
            uuid1,
            rerank=Rerank(prop="text", query="another"),
            single_prompt="What is it? {text}",
        ),
        lambda: collection.generate.near_vector(
            vector1["default"],
            rerank=Rerank(prop="text", query="another"),
            single_prompt="What is it? {text}",
        ),
        lambda: collection.generate.near_text(
            "test",
            rerank=Rerank(prop="text", query="another"),
            single_prompt="What is it? {text}",
        ),
    ]:
        objects = query().objects
        assert len(objects) == 2
        assert objects[0].metadata.rerank_score is not None
        assert objects[0].generated is not None
        assert objects[1].metadata.rerank_score is not None
        assert objects[1].generated is not None

        # The object containing the rerank query term must score higher.
        assert [obj for obj in objects if "another" in obj.properties["text"]][  # type: ignore
            0
        ].metadata.rerank_score > [
            obj for obj in objects if "another" not in obj.properties["text"]
        ][0].metadata.rerank_score
58 changes: 58 additions & 0 deletions integration/test_collection_rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,61 @@ def test_queries_with_rerank_and_group_by(collection_factory: CollectionFactory)
].rerank_score > [group for prop, group in ret.groups.items() if "another" not in prop][
0
].rerank_score


def test_queries_with_rerank_contextualai(collection_factory: CollectionFactory) -> None:
    """Test Contextual AI reranker with various query types.

    Runs every query flavor (bm25/hybrid/near_object/near_vector/near_text)
    with a rerank clause and checks that rerank scores are populated and that
    the object matching the rerank query scores higher.
    """
    api_key = os.environ.get("CONTEXTUAL_API_KEY")
    if api_key is None:
        pytest.skip("No Contextual AI API key found.")

    collection = collection_factory(
        name="Test_test_queries_with_rerank_contextualai",
        reranker_config=wvc.config.Configure.Reranker.contextualai(
            model="ctxl-rerank-v2-instruct-multilingual",
            instruction="Prioritize documents that contain the query term",
        ),
        vectorizer_config=wvc.config.Configure.Vectorizer.text2vec_openai(),
        properties=[wvc.config.Property(name="text", data_type=wvc.config.DataType.TEXT)],
        headers={"X-Contextual-Api-Key": api_key},
        ports=(8086, 50057),
    )
    if collection._connection._weaviate_version.is_lower_than(1, 23, 1):
        pytest.skip("Reranking requires Weaviate 1.23.1 or higher")

    insert = collection.data.insert_many(
        [{"text": "This is a test"}, {"text": "This is another test"}]
    )
    uuid1 = insert.uuids[0]
    vector1 = collection.query.fetch_object_by_id(uuid1, include_vector=True).vector
    assert vector1 is not None

    # The enumerate index was unused; iterate the query thunks directly.
    for query in [
        lambda: collection.query.bm25(
            "test", rerank=wvc.query.Rerank(prop="text", query="another")
        ),
        lambda: collection.query.hybrid(
            "test", rerank=wvc.query.Rerank(prop="text", query="another")
        ),
        lambda: collection.query.near_object(
            uuid1, rerank=wvc.query.Rerank(prop="text", query="another")
        ),
        lambda: collection.query.near_vector(
            vector1["default"], rerank=wvc.query.Rerank(prop="text", query="another")
        ),
        lambda: collection.query.near_text(
            "test", rerank=wvc.query.Rerank(prop="text", query="another")
        ),
    ]:
        objects = query().objects
        assert len(objects) == 2
        assert objects[0].metadata.rerank_score is not None
        assert objects[1].metadata.rerank_score is not None

        # The object containing the rerank query term must score higher.
        assert [obj for obj in objects if "another" in obj.properties["text"]][  # type: ignore
            0
        ].metadata.rerank_score > [
            obj for obj in objects if "another" not in obj.properties["text"]
        ][0].metadata.rerank_score
23 changes: 23 additions & 0 deletions test/collection/test_classes_generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,29 @@ def test_generative_parameters_images_parsing(
),
),
),
(
GenerativeConfig.contextualai(
base_url="http://localhost:8080",
model="v2",
max_tokens=100,
temperature=0.5,
top_p=0.9,
system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context.",
avoid_commentary=False,
)._to_grpc(_GenerativeConfigRuntimeOptions(return_metadata=True)),
generative_pb2.GenerativeProvider(
return_metadata=True,
contextualai=generative_pb2.GenerativeContextualAI(
base_url="http://localhost:8080",
model="v2",
max_tokens=100,
temperature=0.5,
top_p=0.9,
system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context.",
avoid_commentary=False,
),
),
),
],
)
def test_generative_provider_to_grpc(
Expand Down
50 changes: 50 additions & 0 deletions test/collection/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,34 @@ def test_config_with_vectorizer_and_properties(
}
},
),
(
Configure.Generative.contextualai(),
{
"generative-contextualai": {},
},
),
(
Configure.Generative.contextualai(
model="v2",
max_tokens=512,
temperature=0.7,
top_p=0.9,
system_prompt="You are a helpful assistant that provides accurate and informative responses based on the given context.",
avoid_commentary=False,
base_url="https://api.contextual.ai",
),
{
"generative-contextualai": {
"model": "v2",
"maxTokensProperty": 512,
"temperatureProperty": 0.7,
"topPProperty": 0.9,
"systemPromptProperty": "You are a helpful assistant that provides accurate and informative responses based on the given context.",
"avoidCommentaryProperty": False,
"baseURL": "https://api.contextual.ai/",
}
},
),
]


Expand Down Expand Up @@ -1125,6 +1153,28 @@ def test_config_with_generative(
"reranker-transformers": {},
},
),
(
Configure.Reranker.contextualai(),
{
"reranker-contextualai": {},
},
),
(
Configure.Reranker.contextualai(
model="ctxl-rerank-v2-instruct-multilingual",
instruction="Prioritize recent documents",
top_n=5,
base_url="https://api.contextual.ai",
),
{
"reranker-contextualai": {
"model": "ctxl-rerank-v2-instruct-multilingual",
"instruction": "Prioritize recent documents",
"topN": 5,
"baseURL": "https://api.contextual.ai",
}
},
),
]


Expand Down
Loading