Skip to content

Commit 911c72f

Browse files
authored
Add wait_for_database helper function to poll for Knowledge Base database creation (#48)
1 parent dcef3d5 commit 911c72f

File tree

7 files changed

+451
-11
lines changed

7 files changed

+451
-11
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""
2+
Example demonstrating how to use the wait_for_database helper function.
3+
4+
This example shows how to:
5+
1. Create a knowledge base
6+
2. Wait for its database to be ready
7+
3. Handle errors and timeouts appropriately
8+
"""
9+
10+
import os
11+
12+
from gradient import Gradient
13+
from gradient.resources.knowledge_bases import KnowledgeBaseTimeoutError, KnowledgeBaseDatabaseError
14+
15+
16+
def main() -> None:
17+
"""Create a knowledge base and wait for its database to be ready."""
18+
# Initialize the Gradient client
19+
# Note: DIGITALOCEAN_ACCESS_TOKEN must be set in your environment
20+
client = Gradient(
21+
access_token=os.environ.get("DIGITALOCEAN_ACCESS_TOKEN"),
22+
)
23+
24+
# Create a knowledge base
25+
# Replace these values with your actual configuration
26+
kb_response = client.knowledge_bases.create(
27+
name="My Knowledge Base",
28+
region="nyc1", # Choose your preferred region
29+
embedding_model_uuid="your-embedding-model-uuid", # Use your embedding model UUID
30+
)
31+
32+
if not kb_response.knowledge_base or not kb_response.knowledge_base.uuid:
33+
print("Failed to create knowledge base")
34+
return
35+
36+
kb_uuid = kb_response.knowledge_base.uuid
37+
print(f"Created knowledge base: {kb_uuid}")
38+
39+
try:
40+
# Wait for the database to be ready
41+
# Default: 10 minute timeout, 5 second poll interval
42+
print("Waiting for database to be ready...")
43+
result = client.knowledge_bases.wait_for_database(kb_uuid)
44+
print(f"Database status: {result.database_status}") # "ONLINE"
45+
print("Knowledge base is ready!")
46+
47+
# Alternative: Custom timeout and poll interval
48+
# result = client.knowledge_bases.wait_for_database(
49+
# kb_uuid,
50+
# timeout=900.0, # 15 minutes
51+
# poll_interval=10.0 # Check every 10 seconds
52+
# )
53+
54+
except KnowledgeBaseDatabaseError as e:
55+
# Database entered a failed state (DECOMMISSIONED or UNHEALTHY)
56+
print(f"Database failed: {e}")
57+
58+
except KnowledgeBaseTimeoutError as e:
59+
# Database did not become ready within the timeout period
60+
print(f"Timeout: {e}")
61+
62+
63+
if __name__ == "__main__":
64+
main()

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,5 +253,4 @@ known-first-party = ["gradient", "tests"]
253253
[tool.ruff.lint.per-file-ignores]
254254
"bin/**.py" = ["T201", "T203"]
255255
"scripts/**.py" = ["T201", "T203"]
256-
"tests/**.py" = ["T201", "T203"]
257256
"examples/**.py" = ["T201", "T203"]

src/gradient/resources/knowledge_bases/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
)
1919
from .knowledge_bases import (
2020
KnowledgeBasesResource,
21+
KnowledgeBaseTimeoutError,
22+
KnowledgeBaseDatabaseError,
2123
AsyncKnowledgeBasesResource,
2224
KnowledgeBasesResourceWithRawResponse,
2325
AsyncKnowledgeBasesResourceWithRawResponse,
@@ -40,6 +42,8 @@
4042
"AsyncIndexingJobsResourceWithStreamingResponse",
4143
"KnowledgeBasesResource",
4244
"AsyncKnowledgeBasesResource",
45+
"KnowledgeBaseDatabaseError",
46+
"KnowledgeBaseTimeoutError",
4347
"KnowledgeBasesResourceWithRawResponse",
4448
"AsyncKnowledgeBasesResourceWithRawResponse",
4549
"KnowledgeBasesResourceWithStreamingResponse",

src/gradient/resources/knowledge_bases/knowledge_bases.py

Lines changed: 182 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import time
6+
import asyncio
57
from typing import Iterable
68

79
import httpx
@@ -40,7 +42,24 @@
4042
from ...types.knowledge_base_update_response import KnowledgeBaseUpdateResponse
4143
from ...types.knowledge_base_retrieve_response import KnowledgeBaseRetrieveResponse
4244

43-
__all__ = ["KnowledgeBasesResource", "AsyncKnowledgeBasesResource"]
45+
__all__ = [
46+
"KnowledgeBasesResource",
47+
"AsyncKnowledgeBasesResource",
48+
"KnowledgeBaseDatabaseError",
49+
"KnowledgeBaseTimeoutError",
50+
]
51+
52+
53+
class KnowledgeBaseDatabaseError(Exception):
54+
"""Raised when a knowledge base database enters a failed state."""
55+
56+
pass
57+
58+
59+
class KnowledgeBaseTimeoutError(Exception):
60+
"""Raised when waiting for a knowledge base database times out."""
61+
62+
pass
4463

4564

4665
class KnowledgeBasesResource(SyncAPIResource):
@@ -330,6 +349,81 @@ def delete(
330349
cast_to=KnowledgeBaseDeleteResponse,
331350
)
332351

352+
def wait_for_database(
353+
self,
354+
uuid: str,
355+
*,
356+
timeout: float = 600.0,
357+
poll_interval: float = 5.0,
358+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
359+
# The extra values given here take precedence over values defined on the client or passed to this method.
360+
extra_headers: Headers | None = None,
361+
extra_query: Query | None = None,
362+
extra_body: Body | None = None,
363+
) -> KnowledgeBaseRetrieveResponse:
364+
"""
365+
Poll the knowledge base until the database status is ONLINE or a failed state is reached.
366+
367+
This helper function repeatedly calls retrieve() to check the database_status field.
368+
It will wait for the database to become ONLINE, or raise an exception if it enters
369+
a failed state (DECOMMISSIONED or UNHEALTHY) or if the timeout is exceeded.
370+
371+
Args:
372+
uuid: The knowledge base UUID to poll
373+
374+
timeout: Maximum time to wait in seconds (default: 600 seconds / 10 minutes)
375+
376+
poll_interval: Time to wait between polls in seconds (default: 5 seconds)
377+
378+
extra_headers: Send extra headers
379+
380+
extra_query: Add additional query parameters to the request
381+
382+
extra_body: Add additional JSON properties to the request
383+
384+
Returns:
385+
The final KnowledgeBaseRetrieveResponse when the database status is ONLINE
386+
387+
Raises:
388+
KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY)
389+
390+
KnowledgeBaseTimeoutError: If the timeout is exceeded before the database becomes ONLINE
391+
"""
392+
if not uuid:
393+
raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}")
394+
395+
start_time = time.time()
396+
failed_states = {"DECOMMISSIONED", "UNHEALTHY"}
397+
398+
while True:
399+
elapsed = time.time() - start_time
400+
if elapsed >= timeout:
401+
raise KnowledgeBaseTimeoutError(
402+
f"Timeout waiting for knowledge base database to become ready. "
403+
f"Database did not reach ONLINE status within {timeout} seconds."
404+
)
405+
406+
response = self.retrieve(
407+
uuid,
408+
extra_headers=extra_headers,
409+
extra_query=extra_query,
410+
extra_body=extra_body,
411+
)
412+
413+
status = response.database_status
414+
415+
if status == "ONLINE":
416+
return response
417+
418+
if status in failed_states:
419+
raise KnowledgeBaseDatabaseError(f"Knowledge base database entered failed state: {status}")
420+
421+
# Sleep before next poll, but don't exceed timeout
422+
remaining_time = timeout - elapsed
423+
sleep_time = min(poll_interval, remaining_time)
424+
if sleep_time > 0:
425+
time.sleep(sleep_time)
426+
333427

334428
class AsyncKnowledgeBasesResource(AsyncAPIResource):
335429
@cached_property
@@ -618,6 +712,81 @@ async def delete(
618712
cast_to=KnowledgeBaseDeleteResponse,
619713
)
620714

715+
async def wait_for_database(
716+
self,
717+
uuid: str,
718+
*,
719+
timeout: float = 600.0,
720+
poll_interval: float = 5.0,
721+
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
722+
# The extra values given here take precedence over values defined on the client or passed to this method.
723+
extra_headers: Headers | None = None,
724+
extra_query: Query | None = None,
725+
extra_body: Body | None = None,
726+
) -> KnowledgeBaseRetrieveResponse:
727+
"""
728+
Poll the knowledge base until the database status is ONLINE or a failed state is reached.
729+
730+
This helper function repeatedly calls retrieve() to check the database_status field.
731+
It will wait for the database to become ONLINE, or raise an exception if it enters
732+
a failed state (DECOMMISSIONED or UNHEALTHY) or if the timeout is exceeded.
733+
734+
Args:
735+
uuid: The knowledge base UUID to poll
736+
737+
timeout: Maximum time to wait in seconds (default: 600 seconds / 10 minutes)
738+
739+
poll_interval: Time to wait between polls in seconds (default: 5 seconds)
740+
741+
extra_headers: Send extra headers
742+
743+
extra_query: Add additional query parameters to the request
744+
745+
extra_body: Add additional JSON properties to the request
746+
747+
Returns:
748+
The final KnowledgeBaseRetrieveResponse when the database status is ONLINE
749+
750+
Raises:
751+
KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY)
752+
753+
KnowledgeBaseTimeoutError: If the timeout is exceeded before the database becomes ONLINE
754+
"""
755+
if not uuid:
756+
raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}")
757+
758+
start_time = time.time()
759+
failed_states = {"DECOMMISSIONED", "UNHEALTHY"}
760+
761+
while True:
762+
elapsed = time.time() - start_time
763+
if elapsed >= timeout:
764+
raise KnowledgeBaseTimeoutError(
765+
f"Timeout waiting for knowledge base database to become ready. "
766+
f"Database did not reach ONLINE status within {timeout} seconds."
767+
)
768+
769+
response = await self.retrieve(
770+
uuid,
771+
extra_headers=extra_headers,
772+
extra_query=extra_query,
773+
extra_body=extra_body,
774+
)
775+
776+
status = response.database_status
777+
778+
if status == "ONLINE":
779+
return response
780+
781+
if status in failed_states:
782+
raise KnowledgeBaseDatabaseError(f"Knowledge base database entered failed state: {status}")
783+
784+
# Sleep before next poll, but don't exceed timeout
785+
remaining_time = timeout - elapsed
786+
sleep_time = min(poll_interval, remaining_time)
787+
if sleep_time > 0:
788+
await asyncio.sleep(sleep_time)
789+
621790

622791
class KnowledgeBasesResourceWithRawResponse:
623792
def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None:
@@ -638,6 +807,9 @@ def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None:
638807
self.delete = to_raw_response_wrapper(
639808
knowledge_bases.delete,
640809
)
810+
self.wait_for_database = to_raw_response_wrapper(
811+
knowledge_bases.wait_for_database,
812+
)
641813

642814
@cached_property
643815
def data_sources(self) -> DataSourcesResourceWithRawResponse:
@@ -667,6 +839,9 @@ def __init__(self, knowledge_bases: AsyncKnowledgeBasesResource) -> None:
667839
self.delete = async_to_raw_response_wrapper(
668840
knowledge_bases.delete,
669841
)
842+
self.wait_for_database = async_to_raw_response_wrapper(
843+
knowledge_bases.wait_for_database,
844+
)
670845

671846
@cached_property
672847
def data_sources(self) -> AsyncDataSourcesResourceWithRawResponse:
@@ -696,6 +871,9 @@ def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None:
696871
self.delete = to_streamed_response_wrapper(
697872
knowledge_bases.delete,
698873
)
874+
self.wait_for_database = to_streamed_response_wrapper(
875+
knowledge_bases.wait_for_database,
876+
)
699877

700878
@cached_property
701879
def data_sources(self) -> DataSourcesResourceWithStreamingResponse:
@@ -725,6 +903,9 @@ def __init__(self, knowledge_bases: AsyncKnowledgeBasesResource) -> None:
725903
self.delete = async_to_streamed_response_wrapper(
726904
knowledge_bases.delete,
727905
)
906+
self.wait_for_database = async_to_streamed_response_wrapper(
907+
knowledge_bases.wait_for_database,
908+
)
728909

729910
@cached_property
730911
def data_sources(self) -> AsyncDataSourcesResourceWithStreamingResponse:

0 commit comments

Comments
 (0)