diff --git a/python/examples/misc/local_clickhouse_verify.py b/python/examples/misc/local_clickhouse_verify.py index 60d256ad05..210e771e3c 100644 --- a/python/examples/misc/local_clickhouse_verify.py +++ b/python/examples/misc/local_clickhouse_verify.py @@ -86,21 +86,28 @@ def wait_for_sync( def normalize_datetime(value: str | datetime) -> datetime: if isinstance(value, datetime): + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) return value - return datetime.fromisoformat(value.replace("Z", "+00:00")) + parsed = datetime.fromisoformat(value.replace("Z", "+00:00")) + if parsed.tzinfo is None: + return parsed.replace(tzinfo=timezone.utc) + return parsed -def assert_span_fields(span: Any) -> None: +def assert_span_fields(span: Any, require_end_time: bool) -> None: assert span.name, "span.name is required" assert span.trace_id, "span.trace_id is required" assert span.span_id, "span.span_id is required" assert span.start_time, "span.start_time is required" - assert span.end_time, "span.end_time is required" + if require_end_time: + assert span.end_time, "span.end_time is required" assert span.duration_ms is not None, "span.duration_ms is required" assert span.duration_ms >= 0, "span.duration_ms must be non-negative" start_time = normalize_datetime(span.start_time) - end_time = normalize_datetime(span.end_time) - assert end_time >= start_time, "span.end_time must be >= span.start_time" + if require_end_time: + end_time = normalize_datetime(span.end_time) + assert end_time >= start_time, "span.end_time must be >= span.start_time" def assert_recent_span(span: Any, now: datetime, window_minutes: int = 10) -> None: @@ -118,7 +125,10 @@ def assert_trace_contains_names(spans: Iterable[Any], expected: set[str]) -> Non def verify_search_api( - client: Mirascope, search_result: SearchSearchResponse, expected_names: set[str] + client: Mirascope, + search_result: SearchSearchResponse, + expected_names: set[str], + span_name: str, ) -> bool: """Verify Search API endpoints with the synced data.""" spans = search_result.spans or [] @@ -126,7 +136,7 @@ def verify_search_api( print(f" Found {total} spans (showing {len(spans)})") for span in spans[:3]: - assert_span_fields(span) + assert_span_fields(span, require_end_time=False) dur = span.duration_ms dur_str = f"{dur}ms" if dur is not None else "NULL" model = span.model or "N/A" @@ -147,8 +157,8 @@ def verify_search_api( print(f" Trace {trace_id[:16]}... has {len(trace_spans)} spans") assert trace_spans, "trace detail has no spans" for span in trace_spans: - assert_span_fields(span) - assert_trace_contains_names(trace_spans, expected_names) + assert_span_fields(span, require_end_time=True) + assert_trace_contains_names(trace_spans, {span_name}) now = datetime.now(timezone.utc) for span in trace_spans: assert_recent_span(span, now) @@ -227,7 +237,7 @@ def main(): # 4. Verify Search API print("\n[4/4] Verifying Search API...") - success = verify_search_api(api_client, search_result, expected_names) + success = verify_search_api(api_client, search_result, expected_names, span_name) provider.shutdown()