bigquery tests #52

Merged
merged 3 commits on Jan 21, 2025
41 changes: 30 additions & 11 deletions data/run-query-execution.py
@@ -41,9 +41,11 @@
     if collection_name in created_collections:
         continue
 
-    # Delete if exists
+    # Check if collection exists
     if weaviate_client.collections.exists(collection_name):
-        weaviate_client.collections.delete(collection_name)
+        print(f"Collection {collection_name} already exists, skipping creation")
+        created_collections.add(collection_name)
+        continue
 
     properties = []
     for prop in collection['properties']:
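The new guard makes collection creation idempotent: an existing collection is recorded and skipped instead of being dropped and recreated, so previously loaded data survives re-runs. A condensed sketch of the pattern; the `collections` iterable and the `"name"` key are assumptions about the surrounding script, while `weaviate_client` and `created_collections` are as in the diff:

created_collections = set()
for collection in collections:  # assumed: list of schema dicts with "name" / "properties"
    collection_name = collection["name"]
    if collection_name in created_collections:
        continue
    if weaviate_client.collections.exists(collection_name):
        # Reuse the existing collection rather than delete + recreate
        created_collections.add(collection_name)
        continue
    # ...build property definitions from collection["properties"] and create it...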
@@ -97,15 +99,32 @@
print(f"\033[91mQuery execution failed\033[0m") # Red text
print(f"Error: {str(e)}") # Print the error message
query_data['ground_truth_query_result'] = "QUERY EXECUTION FAILED"
print("Connecting to Weaviate...")
weaviate_client = weaviate.connect_to_weaviate_cloud(
cluster_url=WEAVIATE_URL,
auth_credentials=weaviate.auth.AuthApiKey(WEAVIATE_API_KEY),
headers={
"X-OpenAI-Api-Key": OPENAI_API_KEY
}
)
print("Successfully re-connected to Weaviate...")

# Exponential backoff retry logic
max_retries = 3
delay = 2 # Initial delay in seconds
retry_count = 0

while retry_count < max_retries:
try:
print(f"Attempt {retry_count + 1}/{max_retries} to reconnect to Weaviate...")
weaviate_client = weaviate.connect_to_weaviate_cloud(
cluster_url=WEAVIATE_URL,
auth_credentials=weaviate.auth.AuthApiKey(WEAVIATE_API_KEY),
headers={
"X-OpenAI-Api-Key": OPENAI_API_KEY
}
)
print("Successfully re-connected to Weaviate")
break
except Exception as retry_error:
retry_count += 1
if retry_count == max_retries:
print(f"Failed to reconnect after {max_retries} attempts")
raise retry_error
wait_time = delay * (2 ** (retry_count - 1)) # Exponential backoff
print(f"Reconnection failed. Waiting {wait_time} seconds before retry...")
time.sleep(wait_time)

print(f"\nTotal failed queries: {failed_queries}")

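With `delay = 2` and `max_retries = 3`, the loop waits 2 s after the first failed attempt and 4 s after the second; a third failure re-raises. A minimal sketch of the same pattern factored into a helper, assuming the same `WEAVIATE_URL` / `WEAVIATE_API_KEY` / `OPENAI_API_KEY` globals the script uses:

import time
import weaviate

def connect_with_backoff(max_retries: int = 3, delay: int = 2):
    """Retry the cloud connection, doubling the wait after each failure."""
    for attempt in range(max_retries):
        try:
            return weaviate.connect_to_weaviate_cloud(
                cluster_url=WEAVIATE_URL,
                auth_credentials=weaviate.auth.AuthApiKey(WEAVIATE_API_KEY),
                headers={"X-OpenAI-Api-Key": OPENAI_API_KEY},
            )
        except Exception:
            if attempt == max_retries - 1:
                raise
            time.sleep(delay * (2 ** attempt))  # waits 2 s, then 4 s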
8,547 changes: 8,547 additions & 0 deletions data/synthetic-weaviate-queries-with-results.json


486 changes: 486 additions & 0 deletions notebooks/bigquery-pydantic-simple.ipynb


53 changes: 51 additions & 2 deletions src/lm/pydantic_agents/agent.py
@@ -461,11 +461,60 @@ def _stringify_aggregation_results(
"""Stringify the results of the aggregation executor."""
print(f"{Fore.GREEN}Stringifying {len(results)} aggregation results{Style.RESET_ALL}")
stringified_responses = []

for idx, (result, query) in enumerate(zip(results, queries)):
response = f"Aggregation Result {idx+1} (Query: {query}):\n"
for prop in result.properties:
response += f"{prop}:{result.properties[prop]}\n"

# Handle grouped results (AggregateGroupByReturn)
if hasattr(result, 'groups'):
response += f"Found {len(result.groups)} groups:\n"
for group in result.groups:
response += f"\nGroup: {group.grouped_by.prop} = {group.grouped_by.value}\n"
if hasattr(group, 'total_count'):
response += f"Total Count: {group.total_count}\n"

if hasattr(group, 'properties'):
for prop_name, metrics in group.properties.items():
response += f"\n{prop_name} metrics:\n"

# Handle text metrics with top occurrences
if hasattr(metrics, 'top_occurrences') and metrics.top_occurrences:
response += "Top occurrences:\n"
for occurrence in metrics.top_occurrences:
response += f" {occurrence.value}: {occurrence.count}\n"

# Handle numeric metrics
if hasattr(metrics, 'count') and metrics.count is not None:
response += f" count: {metrics.count}\n"
if hasattr(metrics, 'minimum') and metrics.minimum is not None:
response += f" minimum: {metrics.minimum}\n"
if hasattr(metrics, 'maximum') and metrics.maximum is not None:
response += f" maximum: {metrics.maximum}\n"
if hasattr(metrics, 'mean') and metrics.mean is not None:
response += f" mean: {metrics.mean:.2f}\n"
if hasattr(metrics, 'sum_') and metrics.sum_ is not None:
response += f" sum: {metrics.sum_}\n"

# Handle non-grouped results (AggregateReturn)
elif hasattr(result, 'properties'):
for prop_name, metrics in result.properties.items():
response += f"\n{prop_name} metrics:\n"
if hasattr(metrics, 'count') and metrics.count is not None:
response += f" count: {metrics.count}\n"
if hasattr(metrics, 'mean') and metrics.mean is not None:
response += f" mean: {metrics.mean:.2f}\n"
if hasattr(metrics, 'maximum') and metrics.maximum is not None:
response += f" maximum: {metrics.maximum}\n"
if hasattr(metrics, 'minimum') and metrics.minimum is not None:
response += f" minimum: {metrics.minimum}\n"
if hasattr(metrics, 'sum_') and metrics.sum_ is not None:
response += f" sum: {metrics.sum_}\n"

if hasattr(result, 'total_count'):
response += f"\nTotal Count: {result.total_count}\n"

stringified_responses.append(response)

return stringified_responses

def _update_usage(self, usage: Usage):
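For a sense of the output, a hedged sketch: the stub below mimics a non-grouped AggregateReturn with only the attributes `_stringify_aggregation_results` inspects (it is not the client's real class), and the comment shows roughly what the method would produce for it:

from types import SimpleNamespace

fake_metrics = SimpleNamespace(count=13, mean=11.5, maximum=20.0, minimum=3.0, sum_=149.5)
fake_result = SimpleNamespace(properties={"price": fake_metrics}, total_count=13)

# Passed through with the query "average menu price", the stub yields roughly:
#
#   Aggregation Result 1 (Query: average menu price):
#
#   price metrics:
#     count: 13
#     mean: 11.50
#     maximum: 20.0
#     minimum: 3.0
#     sum: 149.5
#
#   Total Count: 13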
204 changes: 114 additions & 90 deletions src/lm/pydantic_agents/executors.py
@@ -107,71 +107,102 @@ def create_filter(f: PropertyFilter) -> _FilterByProperty | None:

 async def process_results(aggregation_results: List[AggregateReturn], search_results: List[QueryReturn]):
     """Process both aggregation and search results."""
+    print(f"\nReceived {len(search_results)} search results and {len(aggregation_results)} aggregation results")
 
     # Process search results
     if search_results:
         for i, result in enumerate(search_results, 1):
             print(f"\nSearch Result Set {i}:")
+            print(f"Number of objects in result set: {len(result.objects)}")
             for item in result.objects:
-                print(f"Menu Item: {item.properties['menuItem']}")
-                print(f"Price: ${item.properties['price']:.2f}")
-                print(f"Vegetarian: {item.properties['isVegetarian']}")
+                print(f"\nObject properties: {list(item.properties.keys())}")
+                for prop, value in item.properties.items():
+                    print(f"{prop}: {value}")
                 print("---")
+    else:
+        print("No search results to process")
 
     # Process aggregation results
     if aggregation_results:
-        print("\nAggregation Results:")
-        for agg_result in aggregation_results:
+        print(f"\nAggregation Results (count: {len(aggregation_results)}):")
+        for idx, agg_result in enumerate(aggregation_results):
+            print(f"\nProcessing aggregation result {idx + 1}")
+            print(f"Result type: {type(agg_result)}")
+            print(f"Available attributes: {dir(agg_result)}")
 
             try:
-                # Handle AggregateGroupByReturn object
-                if hasattr(agg_result, 'groups') and agg_result.groups:  # Check if groups exist and is not empty
-                    for group in agg_result.groups:
-                        print(f"\nGroup: {group.grouped_by.prop} = {group.grouped_by.value}")
-                        print(f"Count: {group.total_count}")
+                # Handle grouped results (AggregateGroupByReturn)
+                if hasattr(agg_result, 'groups'):
+                    print(f"Found grouped result with {len(agg_result.groups)} groups")
+                    for group_idx, group in enumerate(agg_result.groups):
+                        print(f"\nProcessing group {group_idx + 1}")
+                        print(f"Group attributes: {dir(group)}")
+                        print(f"Group: {group.grouped_by.prop} = {group.grouped_by.value}")
 
+                        if hasattr(group, 'total_count'):
+                            print(f"Total Count: {group.total_count}")
 
-                        for prop_name, metrics in group.properties.items():
-                            print(f"{prop_name} metrics:")
-                            if isinstance(metrics, (AggregateInteger, AggregateNumber)):
-                                if metrics.mean is not None:
-                                    print(f"  mean: {metrics.mean:.2f}")
-                                if metrics.maximum is not None:
-                                    print(f"  maximum: {metrics.maximum}")
-                                if metrics.minimum is not None:
-                                    print(f"  minimum: {metrics.minimum}")
-                                if metrics.count is not None:
-                                    print(f"  count: {metrics.count}")
-                                if metrics.sum_ is not None:
-                                    print(f"  sum: {metrics.sum_}")
-                            else:
-                                # Convert metrics object to dictionary, filtering None values
-                                metrics_dict = {k: v for k, v in vars(metrics).items() if not k.startswith('_') and v is not None}
-                                for metric_name, value in metrics_dict.items():
-                                    print(f"  {metric_name}: {value}")
+                        if hasattr(group, 'properties'):
+                            print(f"Properties in group: {list(group.properties.keys())}")
+                            for prop_name, metrics in group.properties.items():
+                                print(f"\n{prop_name} metrics (type: {type(metrics)}):")
+                                print(f"Metrics attributes: {dir(metrics)}")
 
+                                # Handle AggregateText specifically
+                                if hasattr(metrics, 'top_occurrences') and metrics.top_occurrences:
+                                    print("  Top occurrences:")
+                                    for occurrence in metrics.top_occurrences:
+                                        print(f"    {occurrence.value}: {occurrence.count}")
 
+                                # Handle numeric metrics
+                                if isinstance(metrics, (AggregateInteger, AggregateNumber)):
+                                    print(f"  Available numeric metrics: count={metrics.count}, "
+                                          f"min={metrics.minimum}, max={metrics.maximum}, "
+                                          f"mean={metrics.mean}, sum={metrics.sum_}")
+                                    if metrics.count is not None:
+                                        print(f"  count: {metrics.count}")
+                                    if metrics.minimum is not None:
+                                        print(f"  minimum: {metrics.minimum}")
+                                    if metrics.maximum is not None:
+                                        print(f"  maximum: {metrics.maximum}")
+                                    if metrics.mean is not None:
+                                        print(f"  mean: {metrics.mean:.2f}")
+                                    if metrics.sum_ is not None:
+                                        print(f"  sum: {metrics.sum_}")
 
-                # Handle AggregateReturn object
-                if hasattr(agg_result, 'properties'):
+                # Handle non-grouped results (AggregateReturn)
+                elif hasattr(agg_result, 'properties'):
+                    print(f"Found non-grouped result with properties: {list(agg_result.properties.keys())}")
                     for prop_name, metrics in agg_result.properties.items():
-                        print(f"\n{prop_name} metrics:")
+                        print(f"\n{prop_name} metrics (type: {type(metrics)}):")
+                        print(f"Metrics attributes: {dir(metrics)}")
                         if isinstance(metrics, (AggregateInteger, AggregateNumber)):
+                            print(f"  Available numeric metrics: count={metrics.count}, "
+                                  f"min={metrics.minimum}, max={metrics.maximum}, "
+                                  f"mean={metrics.mean}, sum={metrics.sum_}")
+                            if metrics.count is not None:
+                                print(f"  count: {metrics.count}")
                             if metrics.mean is not None:
                                 print(f"  mean: {metrics.mean:.2f}")
                             if metrics.maximum is not None:
                                 print(f"  maximum: {metrics.maximum}")
                             if metrics.minimum is not None:
                                 print(f"  minimum: {metrics.minimum}")
-                            if metrics.count is not None:
-                                print(f"  count: {metrics.count}")
                             if metrics.sum_ is not None:
                                 print(f"  sum: {metrics.sum_}")
                         else:
                             metrics_dict = {k: v for k, v in vars(metrics).items() if not k.startswith('_') and v is not None}
                             for metric_name, value in metrics_dict.items():
                                 print(f"  {metric_name}: {value}")
 
+                else:
+                    print(f"Warning: Unexpected aggregation result type: {type(agg_result)}")
+                    print(f"Available attributes: {dir(agg_result)}")
 
                 if hasattr(agg_result, 'total_count'):
                     print(f"\nTotal Count: {agg_result.total_count}")
 
             except Exception as e:
                 print(f"Error processing aggregation result: {str(e)}")
                 print(f"Result type: {type(agg_result)}")
+                print(f"Error details: {repr(e)}")
+                print(f"Available attributes: {dir(agg_result)}")
+                raise  # Re-raise the exception to see the full stack trace
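The two branches mirror the two result shapes of the v4 Python client's aggregate API: without grouping, `over_all()` returns an `AggregateReturn` whose `.properties` dict maps property names to metric objects, while passing `group_by` returns an `AggregateGroupByReturn` carrying a `.groups` list. A hedged sketch of producing both shapes; the `client` handle and the "Menus" collection with its `price` / `isVegetarian` properties are assumptions taken from elsewhere in this PR:

import weaviate.classes as wc

menus = client.collections.get("Menus")  # assumes a connected v4 client

# Non-grouped -> AggregateReturn (.properties, .total_count)
flat = menus.aggregate.over_all(
    total_count=True,
    return_metrics=wc.query.Metrics("price").number(mean=True, maximum=True, minimum=True),
)

# Grouped -> AggregateGroupByReturn (.groups, each with .grouped_by and .properties)
grouped = menus.aggregate.over_all(
    group_by=wc.aggregate.GroupByAggregate(prop="isVegetarian"),
    total_count=True,
)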


async def aggregate(
@@ -229,9 +260,12 @@ def _build_return_metrics(
             )
         )
     elif isinstance(agg, IntegerPropertyAggregation):
+        metric_name = agg.metrics.value.lower()
+        if metric_name == "sum":
+            metric_name = "sum_"
         metrics_list.append(
             wc.query.Metrics(agg.property_name).integer(
-                **{agg.metrics.value.lower(): True}
+                **{metric_name: True}
             )
         )
     elif isinstance(agg, TextPropertyAggregation):
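The remap above is needed because the v4 client names the sum metric with a trailing underscore (`sum_`, avoiding a clash with Python's builtin `sum`), while the aggregation enum's value is plain "sum". In other words, the splatted call has to end up as:

wc.query.Metrics(agg.property_name).integer(sum_=True)  # not sum=True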
@@ -253,57 +287,47 @@


 async def main():
-    """Example usage of the query executor."""
-    import os
-    import weaviate
-    from weaviate.classes.init import Auth
-    from coordinator.src.agents.nodes.query import QueryAgentDeps, query_agent
-
-    client = weaviate.use_async_with_weaviate_cloud(
-        cluster_url=os.getenv("WEAVIATE_URL"),
-        auth_credentials=Auth.api_key(os.getenv("WEAVIATE_API_KEY")),
-        headers={"X-Openai-Api-Key": os.getenv("OPENAI_APIKEY")},
-    )
-
-    try:
-        collection = client.collections.get("Menus")
-
-        # Example schema for menu items
-        schema = {
-            "properties": {
-                "menuItem": "string",
-                "itemDescription": "string",
-                "price": "number",
-                "isVegetarian": "boolean"
-            }
-        }
-
-        query_agent_deps = QueryAgentDeps(collection_schema=schema)
-
-        # Example query
-        query_result = await query_agent.run(
-            "Find vegetarian menu items under $20",
-            deps=query_agent_deps,
-        )
-
-        # Connect the client before searching
-        await client.connect()
-
-        # Execute search
-        search_results = await search(collection, query_result.data, limit=10)
+    """Test AggregateGroupByReturn object parsing."""
+    from dataclasses import dataclass
+
+    @dataclass
+    class GroupedBy:
+        prop: str
+        value: str
+
+    @dataclass
+    class AggregateInteger:
+        count: int
+        maximum: None
+        mean: None
+        median: None
+        minimum: None
+        mode: None
+        sum_: None
+
+    @dataclass
+    class AggregateGroup:
+        grouped_by: GroupedBy
+        properties: dict
+        total_count: int
-
-        # Execute aggregation
-        agg_results = []
-        if hasattr(query_result.data, 'aggregations'):
-            for agg in query_result.data.aggregations:
-                agg_result = await aggregate(collection, agg)
-                agg_results.append(agg_result)
-
-        # Process and display results
-        await process_results(agg_results, search_results)
-
-    finally:
-        await client.close()
+
+    @dataclass
+    class AggregateGroupByReturn:
+        groups: list
+
+    # Create test object
+    test_results = [
+        AggregateGroupByReturn(groups=[
+            AggregateGroup(
+                grouped_by=GroupedBy(prop='openNow', value='true'),
+                properties={'name': AggregateInteger(count=13, maximum=None, mean=None, median=None, minimum=None, mode=None, sum_=None)},
+                total_count=13
+            )
+        ])
+    ]
+
+    # Process results
+    await process_results(test_results, [])


if __name__ == "__main__":
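Worth noting: the dataclasses above are local stand-ins, so the `isinstance(metrics, (AggregateInteger, AggregateNumber))` checks in `process_results` match the client's real classes, not this mock `AggregateInteger`; the test therefore exercises only the `hasattr`-driven paths (group traversal, total counts, property listing). The body of the `__main__` guard is collapsed in this view; presumably it does the equivalent of:

import asyncio

asyncio.run(main())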