
Merge pull request #52 from weaviate/bigquery
bigquery tests
CShorten authored Jan 21, 2025
2 parents 4c5e1b4 + 3ff066f commit 704dcee
Showing 7 changed files with 9,244 additions and 108 deletions.
41 changes: 30 additions & 11 deletions data/run-query-execution.py
@@ -41,9 +41,11 @@
         if collection_name in created_collections:
             continue
 
-        # Delete if exists
+        # Check if collection exists
         if weaviate_client.collections.exists(collection_name):
-            weaviate_client.collections.delete(collection_name)
+            print(f"Collection {collection_name} already exists, skipping creation")
+            created_collections.add(collection_name)
+            continue
 
         properties = []
         for prop in collection['properties']:
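The new branch makes collection setup idempotent: an existing collection is recorded and skipped rather than dropped and rebuilt, so reruns no longer destroy data. A minimal sketch of the same check-before-create pattern, assuming an already-connected v4 client; the helper name is illustrative, not part of this diff:

def ensure_collection(weaviate_client, collection_name: str):
    """Create the collection only if it does not already exist (sketch)."""
    if weaviate_client.collections.exists(collection_name):
        print(f"Collection {collection_name} already exists, skipping creation")
        return weaviate_client.collections.get(collection_name)
    return weaviate_client.collections.create(collection_name)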
@@ -97,15 +99,32 @@
             print(f"\033[91mQuery execution failed\033[0m")  # Red text
             print(f"Error: {str(e)}")  # Print the error message
             query_data['ground_truth_query_result'] = "QUERY EXECUTION FAILED"
-            print("Connecting to Weaviate...")
-            weaviate_client = weaviate.connect_to_weaviate_cloud(
-                cluster_url=WEAVIATE_URL,
-                auth_credentials=weaviate.auth.AuthApiKey(WEAVIATE_API_KEY),
-                headers={
-                    "X-OpenAI-Api-Key": OPENAI_API_KEY
-                }
-            )
-            print("Successfully re-connected to Weaviate...")
+
+            # Exponential backoff retry logic
+            max_retries = 3
+            delay = 2  # Initial delay in seconds
+            retry_count = 0
+
+            while retry_count < max_retries:
+                try:
+                    print(f"Attempt {retry_count + 1}/{max_retries} to reconnect to Weaviate...")
+                    weaviate_client = weaviate.connect_to_weaviate_cloud(
+                        cluster_url=WEAVIATE_URL,
+                        auth_credentials=weaviate.auth.AuthApiKey(WEAVIATE_API_KEY),
+                        headers={
+                            "X-OpenAI-Api-Key": OPENAI_API_KEY
+                        }
+                    )
+                    print("Successfully re-connected to Weaviate")
+                    break
+                except Exception as retry_error:
+                    retry_count += 1
+                    if retry_count == max_retries:
+                        print(f"Failed to reconnect after {max_retries} attempts")
+                        raise retry_error
+                    wait_time = delay * (2 ** (retry_count - 1))  # Exponential backoff
+                    print(f"Reconnection failed. Waiting {wait_time} seconds before retry...")
+                    time.sleep(wait_time)
 
 print(f"\nTotal failed queries: {failed_queries}")

8,547 changes: 8,547 additions & 0 deletions data/synthetic-weaviate-queries-with-results.json

Large diffs are not rendered by default.
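This 8,547-line file appears to hold the ground-truth results written by data/run-query-execution.py above. A quick integrity check, assuming the file is a JSON list of query records (the exact schema is not shown in this diff); the "QUERY EXECUTION FAILED" sentinel comes from that script:

import json

with open("data/synthetic-weaviate-queries-with-results.json") as f:
    queries = json.load(f)

failed = sum(1 for q in queries
             if q.get("ground_truth_query_result") == "QUERY EXECUTION FAILED")
print(f"{len(queries)} records, {failed} failed executions")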

486 changes: 486 additions & 0 deletions notebooks/bigquery-pydantic-simple.ipynb

Large diffs are not rendered by default.

53 changes: 51 additions & 2 deletions src/lm/pydantic_agents/agent.py
@@ -461,11 +461,60 @@ def _stringify_aggregation_results(
         """Stringify the results of the aggregation executor."""
         print(f"{Fore.GREEN}Stringifying {len(results)} aggregation results{Style.RESET_ALL}")
         stringified_responses = []
+
         for idx, (result, query) in enumerate(zip(results, queries)):
             response = f"Aggregation Result {idx+1} (Query: {query}):\n"
-            for prop in result.properties:
-                response += f"{prop}:{result.properties[prop]}\n"
+
+            # Handle grouped results (AggregateGroupByReturn)
+            if hasattr(result, 'groups'):
+                response += f"Found {len(result.groups)} groups:\n"
+                for group in result.groups:
+                    response += f"\nGroup: {group.grouped_by.prop} = {group.grouped_by.value}\n"
+                    if hasattr(group, 'total_count'):
+                        response += f"Total Count: {group.total_count}\n"
+
+                    if hasattr(group, 'properties'):
+                        for prop_name, metrics in group.properties.items():
+                            response += f"\n{prop_name} metrics:\n"
+
+                            # Handle text metrics with top occurrences
+                            if hasattr(metrics, 'top_occurrences') and metrics.top_occurrences:
+                                response += "Top occurrences:\n"
+                                for occurrence in metrics.top_occurrences:
+                                    response += f"  {occurrence.value}: {occurrence.count}\n"
+
+                            # Handle numeric metrics
+                            if hasattr(metrics, 'count') and metrics.count is not None:
+                                response += f"  count: {metrics.count}\n"
+                            if hasattr(metrics, 'minimum') and metrics.minimum is not None:
+                                response += f"  minimum: {metrics.minimum}\n"
+                            if hasattr(metrics, 'maximum') and metrics.maximum is not None:
+                                response += f"  maximum: {metrics.maximum}\n"
+                            if hasattr(metrics, 'mean') and metrics.mean is not None:
+                                response += f"  mean: {metrics.mean:.2f}\n"
+                            if hasattr(metrics, 'sum_') and metrics.sum_ is not None:
+                                response += f"  sum: {metrics.sum_}\n"
+
+            # Handle non-grouped results (AggregateReturn)
+            elif hasattr(result, 'properties'):
+                for prop_name, metrics in result.properties.items():
+                    response += f"\n{prop_name} metrics:\n"
+                    if hasattr(metrics, 'count') and metrics.count is not None:
+                        response += f"  count: {metrics.count}\n"
+                    if hasattr(metrics, 'mean') and metrics.mean is not None:
+                        response += f"  mean: {metrics.mean:.2f}\n"
+                    if hasattr(metrics, 'maximum') and metrics.maximum is not None:
+                        response += f"  maximum: {metrics.maximum}\n"
+                    if hasattr(metrics, 'minimum') and metrics.minimum is not None:
+                        response += f"  minimum: {metrics.minimum}\n"
+                    if hasattr(metrics, 'sum_') and metrics.sum_ is not None:
+                        response += f"  sum: {metrics.sum_}\n"
+
+            if hasattr(result, 'total_count'):
+                response += f"\nTotal Count: {result.total_count}\n"
+
             stringified_responses.append(response)
+
         return stringified_responses
 
     def _update_usage(self, usage: Usage):
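Each metric in _stringify_aggregation_results is probed with its own hasattr/None pair, five times per branch. A table-driven sketch that collapses the repetition, assuming only the attribute names the method already probes (count, minimum, maximum, mean, sum_):

NUMERIC_METRICS = ("count", "minimum", "maximum", "mean", "sum_")

def format_numeric_metrics(metrics) -> str:
    """Render whichever numeric metrics are present, one per line (sketch)."""
    lines = []
    for name in NUMERIC_METRICS:
        value = getattr(metrics, name, None)
        if value is None:
            continue
        label = "sum" if name == "sum_" else name
        formatted = f"{value:.2f}" if name == "mean" else f"{value}"
        lines.append(f"  {label}: {formatted}")
    return "\n".join(lines)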
204 changes: 114 additions & 90 deletions src/lm/pydantic_agents/executors.py
@@ -107,71 +107,102 @@ def create_filter(f: PropertyFilter) -> _FilterByProperty | None:
 
 async def process_results(aggregation_results: List[AggregateReturn], search_results: List[QueryReturn]):
     """Process both aggregation and search results."""
+    print(f"\nReceived {len(search_results)} search results and {len(aggregation_results)} aggregation results")
 
     # Process search results
     if search_results:
         for i, result in enumerate(search_results, 1):
             print(f"\nSearch Result Set {i}:")
+            print(f"Number of objects in result set: {len(result.objects)}")
             for item in result.objects:
-                print(f"Menu Item: {item.properties['menuItem']}")
-                print(f"Price: ${item.properties['price']:.2f}")
-                print(f"Vegetarian: {item.properties['isVegetarian']}")
+                print(f"\nObject properties: {list(item.properties.keys())}")
+                for prop, value in item.properties.items():
+                    print(f"{prop}: {value}")
                 print("---")
+    else:
+        print("No search results to process")
 
     # Process aggregation results
     if aggregation_results:
-        print("\nAggregation Results:")
-        for agg_result in aggregation_results:
+        print(f"\nAggregation Results (count: {len(aggregation_results)}):")
+        for idx, agg_result in enumerate(aggregation_results):
+            print(f"\nProcessing aggregation result {idx + 1}")
+            print(f"Result type: {type(agg_result)}")
+            print(f"Available attributes: {dir(agg_result)}")
+
+            try:
-            # Handle AggregateGroupByReturn object
-            if hasattr(agg_result, 'groups') and agg_result.groups:  # Check if groups exist and is not empty
-                for group in agg_result.groups:
-                    print(f"\nGroup: {group.grouped_by.prop} = {group.grouped_by.value}")
-                    print(f"Count: {group.total_count}")
+                # Handle grouped results (AggregateGroupByReturn)
+                if hasattr(agg_result, 'groups'):
+                    print(f"Found grouped result with {len(agg_result.groups)} groups")
+                    for group_idx, group in enumerate(agg_result.groups):
+                        print(f"\nProcessing group {group_idx + 1}")
+                        print(f"Group attributes: {dir(group)}")
+                        print(f"Group: {group.grouped_by.prop} = {group.grouped_by.value}")
+
+                        if hasattr(group, 'total_count'):
+                            print(f"Total Count: {group.total_count}")
+
-                    for prop_name, metrics in group.properties.items():
-                        print(f"{prop_name} metrics:")
-                        if isinstance(metrics, (AggregateInteger, AggregateNumber)):
-                            if metrics.mean is not None:
-                                print(f"  mean: {metrics.mean:.2f}")
-                            if metrics.maximum is not None:
-                                print(f"  maximum: {metrics.maximum}")
-                            if metrics.minimum is not None:
-                                print(f"  minimum: {metrics.minimum}")
-                            if metrics.count is not None:
-                                print(f"  count: {metrics.count}")
-                            if metrics.sum_ is not None:
-                                print(f"  sum: {metrics.sum_}")
-                        else:
-                            # Convert metrics object to dictionary, filtering None values
-                            metrics_dict = {k: v for k, v in vars(metrics).items() if not k.startswith('_') and v is not None}
-                            for metric_name, value in metrics_dict.items():
-                                print(f"  {metric_name}: {value}")
+                        if hasattr(group, 'properties'):
+                            print(f"Properties in group: {list(group.properties.keys())}")
+                            for prop_name, metrics in group.properties.items():
+                                print(f"\n{prop_name} metrics (type: {type(metrics)}):")
+                                print(f"Metrics attributes: {dir(metrics)}")
+
+                                # Handle AggregateText specifically
+                                if hasattr(metrics, 'top_occurrences') and metrics.top_occurrences:
+                                    print("  Top occurrences:")
+                                    for occurrence in metrics.top_occurrences:
+                                        print(f"    {occurrence.value}: {occurrence.count}")
+
+                                # Handle numeric metrics
+                                if isinstance(metrics, (AggregateInteger, AggregateNumber)):
+                                    print(f"  Available numeric metrics: count={metrics.count}, "
+                                          f"min={metrics.minimum}, max={metrics.maximum}, "
+                                          f"mean={metrics.mean}, sum={metrics.sum_}")
+                                    if metrics.count is not None:
+                                        print(f"  count: {metrics.count}")
+                                    if metrics.minimum is not None:
+                                        print(f"  minimum: {metrics.minimum}")
+                                    if metrics.maximum is not None:
+                                        print(f"  maximum: {metrics.maximum}")
+                                    if metrics.mean is not None:
+                                        print(f"  mean: {metrics.mean:.2f}")
+                                    if metrics.sum_ is not None:
+                                        print(f"  sum: {metrics.sum_}")
 
-            # Handle AggregateReturn object
-            if hasattr(agg_result, 'properties'):
+                # Handle non-grouped results (AggregateReturn)
+                elif hasattr(agg_result, 'properties'):
+                    print(f"Found non-grouped result with properties: {list(agg_result.properties.keys())}")
                     for prop_name, metrics in agg_result.properties.items():
-                        print(f"\n{prop_name} metrics:")
+                        print(f"\n{prop_name} metrics (type: {type(metrics)}):")
+                        print(f"Metrics attributes: {dir(metrics)}")
                         if isinstance(metrics, (AggregateInteger, AggregateNumber)):
+                            print(f"  Available numeric metrics: count={metrics.count}, "
+                                  f"min={metrics.minimum}, max={metrics.maximum}, "
+                                  f"mean={metrics.mean}, sum={metrics.sum_}")
+                            if metrics.count is not None:
+                                print(f"  count: {metrics.count}")
                             if metrics.mean is not None:
                                 print(f"  mean: {metrics.mean:.2f}")
                             if metrics.maximum is not None:
                                 print(f"  maximum: {metrics.maximum}")
                             if metrics.minimum is not None:
                                 print(f"  minimum: {metrics.minimum}")
-                            if metrics.count is not None:
-                                print(f"  count: {metrics.count}")
                             if metrics.sum_ is not None:
                                 print(f"  sum: {metrics.sum_}")
                         else:
                             metrics_dict = {k: v for k, v in vars(metrics).items() if not k.startswith('_') and v is not None}
                             for metric_name, value in metrics_dict.items():
                                 print(f"  {metric_name}: {value}")
 
+                else:
+                    print(f"Warning: Unexpected aggregation result type: {type(agg_result)}")
+                    print(f"Available attributes: {dir(agg_result)}")
+
                 if hasattr(agg_result, 'total_count'):
                     print(f"\nTotal Count: {agg_result.total_count}")
 
+            except Exception as e:
+                print(f"Error processing aggregation result: {str(e)}")
+                print(f"Result type: {type(agg_result)}")
+                print(f"Error details: {repr(e)}")
+                print(f"Available attributes: {dir(agg_result)}")
+                raise  # Re-raise the exception to see the full stack trace
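process_results now dispatches on shape rather than concrete type: grouped results carry .groups, flat ones carry .properties. A condensed sketch of that duck-typed dispatch; it runs against any objects exposing these attributes, including the dataclass stand-ins defined in main() further down:

def iter_metric_blocks(agg_result):
    """Yield (label, properties dict) pairs for either aggregation shape (sketch)."""
    if hasattr(agg_result, "groups"):
        for group in agg_result.groups:
            yield f"{group.grouped_by.prop} = {group.grouped_by.value}", group.properties
    elif hasattr(agg_result, "properties"):
        yield "ungrouped", agg_result.properties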


async def aggregate(
@@ -229,9 +260,12 @@ def _build_return_metrics(
             )
         )
     elif isinstance(agg, IntegerPropertyAggregation):
+        metric_name = agg.metrics.value.lower()
+        if metric_name == "sum":
+            metric_name = "sum_"
         metrics_list.append(
             wc.query.Metrics(agg.property_name).integer(
-                **{agg.metrics.value.lower(): True}
+                **{metric_name: True}
             )
         )
     elif isinstance(agg, TextPropertyAggregation):
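The three inserted lines work around a naming quirk in the v4 Python client: Metrics(...).integer() takes sum_ with a trailing underscore (sum would shadow the Python builtin), while the aggregation enum's lowercased value is presumably the plain string "sum". A standalone illustration of the remap, assuming the wc alias stands for weaviate.classes as elsewhere in this file, with "price" as an illustrative property name:

import weaviate.classes as wc

metric_name = "sum"          # e.g. agg.metrics.value.lower()
if metric_name == "sum":
    metric_name = "sum_"     # the client's keyword is sum_, not sum
metrics = wc.query.Metrics("price").integer(**{metric_name: True})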
@@ -253,57 +287,47 @@
 
 
 async def main():
-    """Example usage of the query executor."""
-    import os
-    import weaviate
-    from weaviate.classes.init import Auth
-    from coordinator.src.agents.nodes.query import QueryAgentDeps, query_agent
-
-    client = weaviate.use_async_with_weaviate_cloud(
-        cluster_url=os.getenv("WEAVIATE_URL"),
-        auth_credentials=Auth.api_key(os.getenv("WEAVIATE_API_KEY")),
-        headers={"X-Openai-Api-Key": os.getenv("OPENAI_APIKEY")},
-    )
-
-    try:
-        collection = client.collections.get("Menus")
-
-        # Example schema for menu items
-        schema = {
-            "properties": {
-                "menuItem": "string",
-                "itemDescription": "string",
-                "price": "number",
-                "isVegetarian": "boolean"
-            }
-        }
-
-        query_agent_deps = QueryAgentDeps(collection_schema=schema)
-
-        # Example query
-        query_result = await query_agent.run(
-            "Find vegetarian menu items under $20",
-            deps=query_agent_deps,
-        )
-
-        # Connect the client before searching
-        await client.connect()
-
-        # Execute search
-        search_results = await search(collection, query_result.data, limit=10)
+    """Test AggregateGroupByReturn object parsing."""
+    from dataclasses import dataclass
+
+    @dataclass
+    class GroupedBy:
+        prop: str
+        value: str
+
+    @dataclass
+    class AggregateInteger:
+        count: int
+        maximum: None
+        mean: None
+        median: None
+        minimum: None
+        mode: None
+        sum_: None
+
+    @dataclass
+    class AggregateGroup:
+        grouped_by: GroupedBy
+        properties: dict
+        total_count: int
 
-        # Execute aggregation
-        agg_results = []
-        if hasattr(query_result.data, 'aggregations'):
-            for agg in query_result.data.aggregations:
-                agg_result = await aggregate(collection, agg)
-                agg_results.append(agg_result)
-
-        # Process and display results
-        await process_results(agg_results, search_results)
-
-    finally:
-        await client.close()
+    @dataclass
+    class AggregateGroupByReturn:
+        groups: list
+
+    # Create test object
+    test_results = [
+        AggregateGroupByReturn(groups=[
+            AggregateGroup(
+                grouped_by=GroupedBy(prop='openNow', value='true'),
+                properties={'name': AggregateInteger(count=13, maximum=None, mean=None, median=None, minimum=None, mode=None, sum_=None)},
+                total_count=13
+            )
+        ])
+    ]
+
+    # Process results
+    await process_results(test_results, [])
 
 
 if __name__ == "__main__":
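The rewritten main() needs no live Weaviate cluster: it feeds dataclass stand-ins straight into process_results. The body of the __main__ guard is truncated above, but assuming the conventional asyncio entry point, the test runs directly:

import asyncio

if __name__ == "__main__":
    asyncio.run(main())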