Skip to content

Commit 198759e

Browse files
authored
fix(py/genkit): ai.run() -> ai.run_main() and fix docstrings/commentary. (#2731)
fix(py/genkit): `ai.run()` -> `ai.run_main()` and fix docstrings/commentary. CHANGELOG: - [ ] Update remaining instances of `ai.run()` -> `ai.run_main()` throughout. - [ ] Fix docstrings and commentary. - [ ] Fix types for `run_main()`.
1 parent d38c193 commit 198759e

File tree

7 files changed

+77
-34
lines changed

7 files changed

+77
-34
lines changed

py/packages/genkit/src/genkit/ai/_base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import threading
2121
from collections.abc import Coroutine
2222
from http.server import HTTPServer
23-
from typing import Any
23+
from typing import Any, TypeVar
2424

2525
import structlog
2626

@@ -36,6 +36,8 @@
3636

3737
logger = structlog.get_logger(__name__)
3838

39+
T = TypeVar('T')
40+
3941

4042
class GenkitBase(GenkitRegistry):
4143
"""Base class with shared infra for Genkit instances (sync and async)."""
@@ -58,7 +60,7 @@ def __init__(
5860
self._initialize_server(reflection_server_spec)
5961
self._initialize_registry(model, plugins)
6062

61-
def run_main(self, coro: Coroutine[Any, Any, Any] | None = None) -> Any:
63+
def run_main(self, coro: Coroutine[Any, Any, T] | None = None) -> T:
6264
"""Runs the provided coroutine on an event loop.
6365
6466
Args:
@@ -67,7 +69,6 @@ def run_main(self, coro: Coroutine[Any, Any, Any] | None = None) -> Any:
6769
Returns:
6870
The result of the coroutine.
6971
"""
70-
7172
if not coro:
7273

7374
async def blank_coro():
@@ -78,7 +79,7 @@ async def blank_coro():
7879
result = None
7980
if self._loop:
8081

81-
async def run() -> Any:
82+
async def run() -> T:
8283
return await coro
8384

8485
result = run_async(self._loop, run)

py/packages/genkit/src/genkit/ai/_base_async.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#
1515
# SPDX-License-Identifier: Apache-2.0
1616

17-
"""Base/shared implementation for Genkit user-facing API."""
17+
"""Asynchronous server gateway interface implementation for Genkit."""
1818

1919
from collections.abc import Coroutine
2020
from typing import Any, TypeVar
@@ -91,15 +91,17 @@ def resolver(kind, name, plugin=plugin):
9191
else:
9292
raise ValueError(f'Invalid {plugin=} provided to Genkit: must be of type `genkit.ai.Plugin`')
9393

94-
def run(self, coro: Coroutine[Any, Any, T]) -> T:
94+
def run_main(self, coro: Coroutine[Any, Any, T]) -> T:
9595
"""Run the user's main coroutine.
9696
97-
In development mode (`GENKIT_ENV=dev`), this starts the Genkit reflection
98-
server and runs the user's coroutine concurrently within the same event loop,
99-
blocking until the server is stopped (e.g., via Ctrl+C).
97+
In development mode (`GENKIT_ENV=dev`), this starts the Genkit
98+
reflection server and runs the user's coroutine concurrently within the
99+
same event loop, blocking until the server is stopped (e.g., via
100+
Ctrl+C).
100101
101102
In production mode, this simply runs the user's coroutine to completion
102-
using `asyncio.run()`.
103+
using `uvloop.run()` for performance if available, otherwise
104+
`asyncio.run()`.
103105
104106
Args:
105107
coro: The main coroutine provided by the user.
@@ -111,7 +113,6 @@ def run(self, coro: Coroutine[Any, Any, T]) -> T:
111113
logger.info('Running in production mode.')
112114
return run_loop(coro)
113115

114-
# Development mode: Start reflection server and user coro concurrently.
115116
logger.info('Running in development mode.')
116117

117118
spec = self._reflection_server_spec
@@ -138,24 +139,23 @@ async def run_user_coro_wrapper():
138139
finally:
139140
user_task_finished_event.set()
140141

141-
reflection_server = make_reflection_server(self.registry, spec)
142+
reflection_server = _make_reflection_server(self.registry, spec)
142143

143144
try:
144145
async with RuntimeManager(spec):
145-
# We use anyio's task group because it's compatible with
146+
# We use anyio.TaskGroup because it is compatible with
146147
# asyncio's event loop and works with Python 3.10
147-
# (asyncio.create_task_group was added in 3.11, and we can switch
148-
# to that if we drop support for 3.10).
148+
# (asyncio.TaskGroup was added in 3.11, and we can switch to
149+
# that when we drop support for 3.10).
149150
async with anyio.create_task_group() as tg:
150151
# Start reflection server in the background.
151152
tg.start_soon(reflection_server.serve, name='genkit-reflection-server')
152-
await anyio.sleep(0.2)
153-
logger.info(f'Started Genkit reflection server at {spec.scheme}://{spec.host}:{spec.port}')
153+
logger.info(f'Started Genkit reflection server at {spec.url}')
154154

155155
# Start the (potentially short-lived) user coroutine wrapper
156156
tg.start_soon(run_user_coro_wrapper, name='genkit-user-coroutine')
157157

158-
# Block here until the TaskGroup is cancelled (e.g. Ctrl+C)
158+
# Block here until the task group is canceled (e.g. Ctrl+C)
159159
# or a task raises an unhandled exception. It should not
160160
# exit just because the user coroutine finishes.
161161

@@ -166,7 +166,7 @@ async def run_user_coro_wrapper():
166166
logger.exception(e)
167167
raise
168168

169-
# After the TaskGroup finishes (error or cancellation)
169+
# After the TaskGroup finishes (error or cancelation).
170170
if user_task_finished_event.is_set():
171171
logger.debug('User coroutine finished before TaskGroup exit.')
172172
return user_result
@@ -177,8 +177,19 @@ async def run_user_coro_wrapper():
177177
return anyio.run(dev_runner)
178178

179179

180-
def make_reflection_server(registry: GenkitRegistry, spec: ServerSpec) -> uvicorn.Server:
181-
"""Make a reflection server for the given registry and spec."""
180+
def _make_reflection_server(registry: GenkitRegistry, spec: ServerSpec) -> uvicorn.Server:
181+
"""Make a reflection server for the given registry and spec.
182+
183+
This is a helper function to make it easier to test the reflection server
184+
in isolation.
185+
186+
Args:
187+
registry: The registry to use for the reflection server.
188+
spec: The spec to use for the reflection server.
189+
190+
Returns:
191+
A uvicorn server instance.
192+
"""
182193
app = create_reflection_asgi_app(registry=registry)
183194
config = uvicorn.Config(app, host=spec.host, port=spec.port, loop='asyncio')
184195
return uvicorn.Server(config)

py/packages/genkit/src/genkit/ai/_runtime.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,18 +145,18 @@ def __init__(self, spec: ServerSpec, runtime_dir: str | Path | None = None):
145145
"""
146146
self.spec = spec
147147
if runtime_dir is None:
148-
self.runtime_dir = Path(os.getcwd()) / DEFAULT_RUNTIME_DIR_NAME
148+
self._runtime_dir = Path(os.getcwd()) / DEFAULT_RUNTIME_DIR_NAME
149149
else:
150-
self.runtime_dir = Path(runtime_dir)
150+
self._runtime_dir = Path(runtime_dir)
151151

152152
self._runtime_file_path: Path | None = None
153153

154154
async def __aenter__(self) -> RuntimeManager:
155155
"""Create the runtime directory and file."""
156156
try:
157-
await logger.adebug(f'Ensuring runtime directory exists: {self.runtime_dir}')
158-
self.runtime_dir.mkdir(parents=True, exist_ok=True)
159-
runtime_file_path = _create_and_write_runtime_file(self.runtime_dir, self.spec)
157+
await logger.adebug(f'Ensuring runtime directory exists: {self._runtime_dir}')
158+
self._runtime_dir.mkdir(parents=True, exist_ok=True)
159+
runtime_file_path = _create_and_write_runtime_file(self._runtime_dir, self.spec)
160160
_register_atexit_cleanup_handler(runtime_file_path)
161161

162162
except Exception as e:
@@ -186,9 +186,9 @@ async def __aexit__(
186186
def __enter__(self) -> RuntimeManager:
187187
"""Synchronous entry point: Create the runtime directory and file."""
188188
try:
189-
logger.debug(f'[sync] Ensuring runtime directory exists: {self.runtime_dir}')
190-
self.runtime_dir.mkdir(parents=True, exist_ok=True)
191-
self._runtime_file_path = _create_and_write_runtime_file(self.runtime_dir, self.spec)
189+
logger.debug(f'[sync] Ensuring runtime directory exists: {self._runtime_dir}')
190+
self._runtime_dir.mkdir(parents=True, exist_ok=True)
191+
self._runtime_file_path = _create_and_write_runtime_file(self._runtime_dir, self.spec)
192192
_register_atexit_cleanup_handler(self._runtime_file_path)
193193

194194
except Exception as e:

py/samples/firestore-retreiver/src/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ async def index_documents() -> None:
9393

9494
@ai.flow()
9595
async def retreive_documents():
96+
"""Retrieves the film documents from Firestore."""
9697
return await ai.retrieve(
9798
query=Document.from_text('sci-fi film'),
9899
retriever=firestore_action_name('filmsretriever'),

py/samples/google-genai-context-caching/src/context_caching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,4 @@ async def main() -> None:
101101

102102

103103
if __name__ == '__main__':
104-
ai.run(main())
104+
ai.run_main(main())

py/samples/short-n-long/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,9 @@ To start as a server in dev mode:
3939
```bash
4040
genkit start -- uv run src/short_n_long/main.py --server
4141
```
42+
43+
## Running with a specific version of Python
44+
45+
```bash
46+
genkit start -- uv run --python python3.10 src/short_n_long/main.py
47+
```

py/samples/short-n-long/src/short_n_long/main.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
"""
4343

4444
import argparse
45+
import asyncio
4546

4647
import structlog
4748
import uvicorn
@@ -114,8 +115,8 @@ async def simple_generate_with_tools_flow(value: int) -> str:
114115
return response.text
115116

116117

117-
@ai.tool(name='gablorkenTool2')
118-
def gablorken_tool2(input_: GablorkenInput, ctx: ToolRunContext):
118+
@ai.tool(name='interruptingTool')
119+
def interrupting_tool(input_: GablorkenInput, ctx: ToolRunContext):
119120
"""The user-defined tool function.
120121
121122
Args:
@@ -146,7 +147,7 @@ async def simple_generate_with_interrupts(value: int) -> str:
146147
content=[TextPart(text=f'what is a gablorken of {value}')],
147148
),
148149
],
149-
tools=['gablorkenTool2'],
150+
tools=['interruptingTool'],
150151
)
151152
await logger.ainfo(f'len(response.tool_requests)={len(response1.tool_requests)}')
152153
if len(response1.interrupts) == 0:
@@ -233,6 +234,29 @@ async def say_hi_stream(name: str, ctx):
233234
return result
234235

235236

237+
@ai.flow()
238+
async def stream_greeting(name: str, ctx) -> str:
239+
"""Stream a greeting for the given name.
240+
241+
Args:
242+
name: the name to send to test function
243+
ctx: the context of the tool
244+
245+
Returns:
246+
The generated response with a function.
247+
"""
248+
chunks = [
249+
'hello',
250+
name,
251+
'how are you?',
252+
]
253+
for data in chunks:
254+
await asyncio.sleep(1)
255+
ctx.send_chunk(data)
256+
257+
return 'test streaming response'
258+
259+
236260
class Skills(BaseModel):
237261
"""Skills for an RPG character."""
238262

@@ -358,4 +382,4 @@ async def main(ai: Genkit) -> None:
358382
if __name__ == '__main__':
359383
config: argparse.Namespace = parse_args()
360384
runner = server_main if config.server else main
361-
ai.run(runner(ai))
385+
ai.run_main(runner(ai))

0 commit comments

Comments
 (0)