Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Local Jupyter Code Executor for Version 4 #4795

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/packages/autogen-ext/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,13 @@ video-surfer = [
"ffmpeg-python",
"openai-whisper",
]

grpc = [
"grpcio~=1.62.0", # TODO: update this once we have a stable version.
]
jupyter-executor = [
"ipykernel>=6.29.5",
"jupyter-kernel-gateway>=3.0.1",
]

[tool.hatch.build.targets.wheel]
packages = ["src/autogen_ext"]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from ._jupyter_client import JupyterClient
from ._jupyter_code_executor import JupyterCodeExecutor, JupyterCodeResult
from ._jupyter_connectable import JupyterConnectable
from ._jupyter_connection_info import JupyterConnectionInfo
from ._local_jupyter_server import LocalJupyterServer

__all__ = [
"JupyterConnectable",
"JupyterConnectionInfo",
"JupyterClient",
"LocalJupyterServer",
"JupyterCodeExecutor",
"JupyterCodeResult",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
from __future__ import annotations

import sys
from dataclasses import dataclass
from types import TracebackType
from typing import Any, AsyncGenerator, cast

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

import datetime
import json
import uuid

import requests
from requests.adapters import HTTPAdapter, Retry
from websockets.asyncio.client import ClientConnection, connect

from ._jupyter_connection_info import JupyterConnectionInfo


class JupyterClient:
def __init__(self, connection_info: JupyterConnectionInfo):
"""(Experimental) A client for communicating with a Jupyter gateway server.

Args:
connection_info (JupyterConnectionInfo): Connection information
"""
self._connection_info = connection_info
self._session = requests.Session()
retries = Retry(total=5, backoff_factor=0.1)
self._session.mount("http://", HTTPAdapter(max_retries=retries))

def _get_headers(self) -> dict[str, str]:
if self._connection_info.token is None:
return {}
return {"Authorization": f"token {self._connection_info.token}"}

def _get_cookies(self) -> str:
cookies = self._session.cookies.get_dict()
return "; ".join([f"{name}={value}" for name, value in cookies.items()])

def _get_api_base_url(self) -> str:
protocol = "https" if self._connection_info.use_https else "http"
port = f":{self._connection_info.port}" if self._connection_info.port else ""
return f"{protocol}://{self._connection_info.host}{port}"

def _get_ws_base_url(self) -> str:
port = f":{self._connection_info.port}" if self._connection_info.port else ""
return f"ws://{self._connection_info.host}{port}"

def list_kernel_specs(self) -> dict[str, dict[str, str]]:
response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers())
return cast(dict[str, dict[str, str]], response.json())

def list_kernels(self) -> list[dict[str, str]]:
response = self._session.get(f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers())
return cast(list[dict[str, str]], response.json())

def start_kernel(self, kernel_spec_name: str) -> str:
"""Start a new kernel.

Args:
kernel_spec_name (str): Name of the kernel spec to start

Returns:
str: ID of the started kernel
"""

response = self._session.post(
f"{self._get_api_base_url()}/api/kernels",
headers=self._get_headers(),
json={"name": kernel_spec_name},
)
return cast(str, response.json()["id"])

def delete_kernel(self, kernel_id: str) -> None:
response = self._session.delete(
f"{self._get_api_base_url()}/api/kernels/{kernel_id}", headers=self._get_headers()
)
response.raise_for_status()

def restart_kernel(self, kernel_id: str) -> None:
response = self._session.post(
f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart", headers=self._get_headers()
)
response.raise_for_status()

async def get_kernel_client(self, kernel_id: str) -> JupyterKernelClient:
ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels"
headers = self._get_headers()
headers["Cookie"] = self._get_cookies()
websocket = await connect(ws_url, additional_headers=headers)
return JupyterKernelClient(websocket)


class JupyterKernelClient:
"""A client for communicating with a Jupyter kernel."""

@dataclass
class ExecutionResult:
@dataclass
class DataItem:
mime_type: str
data: str

is_ok: bool
output: str
data_items: list[DataItem]

def __init__(self, websocket: ClientConnection):
self._session_id: str = uuid.uuid4().hex
self._websocket = websocket

async def _send_message(self, *, content: dict[str, Any], channel: str, message_type: str) -> str:
timestamp = datetime.datetime.now().isoformat()
message_id = uuid.uuid4().hex
message = {
"header": {
"username": "autogen",
"version": "5.0",
"session": self._session_id,
"msg_id": message_id,
"msg_type": message_type,
"date": timestamp,
},
"parent_header": {},
"channel": channel,
"content": content,
"metadata": {},
"buffers": {},
}
await self._websocket.send(json.dumps(message))
return message_id

async def wait_for_ready(self) -> None:
message_id = await self._send_message(content={}, channel="shell", message_type="kernel_info_request")

async for message in self._receive_message(message_id):
if message["msg_type"] == "kernel_info_reply":
break

async def execute(self, code: str) -> ExecutionResult:
message_id = await self._send_message(
content={
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
channel="shell",
message_type="execute_request",
)

text_output: list[str] = []
data_output: list[JupyterKernelClient.ExecutionResult.DataItem] = []

async for message in self._receive_message(message_id):
content = message["content"]
match message["msg_type"]:
case "execute_result" | "display_data":
for data_type, data in content["data"].items():
match data_type:
case "text/plain":
text_output.append(data)
case type if type.startswith("image/") or type == "text/html":
data_output.append(self.ExecutionResult.DataItem(mime_type=data_type, data=data))
case _:
text_output.append(json.dumps(data))
case "stream":
text_output.append(content["text"])
case "error":
return JupyterKernelClient.ExecutionResult(
is_ok=False,
output="\n".join(content["traceback"]),
data_items=[],
)
case _:
pass

if message["msg_type"] == "status" and content["execution_state"] == "idle":
break

return JupyterKernelClient.ExecutionResult(
is_ok=True, output="\n".join([output for output in text_output]), data_items=data_output
)

async def __aenter__(self) -> Self:
return self

async def __aexit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
await self._websocket.close()

async def _receive_message(self, message_id: str) -> AsyncGenerator[dict[str, Any]]:
async for data in self._websocket:
if isinstance(data, bytes):
data = data.decode("utf-8")
message = cast(dict[str, Any], json.loads(data))
if message.get("parent_header", {}).get("msg_id") == message_id:
yield message
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import asyncio
import base64
import json
import re
import sys
import uuid
from dataclasses import dataclass
from pathlib import Path
from types import TracebackType

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from autogen_core import CancellationToken
from autogen_core.code_executor import CodeBlock, CodeExecutor, CodeResult

from .._common import silence_pip
from ._jupyter_connectable import JupyterConnectable


@dataclass
class JupyterCodeResult(CodeResult):
"""A code result class for Jupyter code executor."""

output_files: list[Path]


class JupyterCodeExecutor(CodeExecutor):
def __init__(
self,
server: JupyterConnectable,
kernel_name: str = "python3",
timeout: int = 60,
output_dir: Path = Path("."),
):
"""A code executor class that executes code statefully using
a Jupyter server supplied to this class.

Each execution is stateful and can access variables created from previous
executions in the same session.

Args:
server (JupyterConnectable): The Jupyter server to use.
kernel_name (str): The kernel name to use. Make sure it is installed.
By default, it is "python3".
timeout (int): The timeout for code execution, by default 60.
output_dir (Path): The directory to save output files, by default ".".
"""
if timeout < 1:
raise ValueError("Timeout must be greater than or equal to 1.")

self._jupyter_client = server.get_client()
self._kernel_name = kernel_name
self._timeout = timeout
self._output_dir = output_dir
self.start()

async def execute_code_blocks(
self, code_blocks: list[CodeBlock], cancellation_token: CancellationToken
) -> JupyterCodeResult:
"""Execute code blocks and return the result.

Args:
code_blocks (list[CodeBlock]): The code blocks to execute.

Returns:
JupyterCodeResult: The result of the code execution.

Raises:
asyncio.TimeoutError: Code execution timeouts
asyncio.CancelledError: CancellationToken evoked during execution
"""
if self._kernel_id is None:
raise ValueError("Kernel not running")

async with await self._jupyter_client.get_kernel_client(self._kernel_id) as kernel_client:
wait_for_ready_task = asyncio.create_task(kernel_client.wait_for_ready())
cancellation_token.link_future(wait_for_ready_task)
await asyncio.wait_for(wait_for_ready_task, timeout=self._timeout)

outputs: list[str] = []
output_files: list[Path] = []
exit_code = 0

for code_block in code_blocks:
code = silence_pip(code_block.code, code_block.language)
execute_task = asyncio.create_task(kernel_client.execute(code))
cancellation_token.link_future(execute_task)
result = await asyncio.wait_for(execute_task, timeout=self._timeout)

# Clean ansi escape sequences
result.output = re.sub(r"\x1b\[[0-9;]*[A-Za-z]", "", result.output)
outputs.append(result.output)

if not result.is_ok:
exit_code = 1
break

for data in result.data_items:
match data.mime_type:
case "image/png":
path = self._save_image(data.data)
output_files.append(path)
case "image/jpeg":
# TODO: Should this also be encoded? Images are encoded as both png and jpg
pass
case "text/html":
path = self._save_html(data.data)
output_files.append(path)
case _:
outputs.append(json.dumps(data.data))

return JupyterCodeResult(exit_code=exit_code, output="\n".join(outputs), output_files=output_files)

async def restart(self) -> None:
"""Restart the code executor."""
if self._kernel_id is None:
self.start()
else:
self._jupyter_client.restart_kernel(self._kernel_id)
self._jupyter_kernel_client = self._jupyter_client.get_kernel_client(self._kernel_id)

def start(self) -> None:
"""Start the kernel."""
available_kernels = self._jupyter_client.list_kernel_specs()
if self._kernel_name not in available_kernels["kernelspecs"]:
raise ValueError(f"Kernel {self._kernel_name} is not installed.")

self._kernel_id = self._jupyter_client.start_kernel(self._kernel_name)

def stop(self) -> None:
"""Stop the kernel."""
if self._kernel_id is not None:
self._jupyter_client.delete_kernel(self._kernel_id)
self._kernel_id = None

def __enter__(self) -> Self:
return self

def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
self.stop()

def _save_image(self, image_data_base64: str) -> Path:
"""Save image data to a file."""
image_data = base64.b64decode(image_data_base64)
path = self._output_dir / f"{uuid.uuid4().hex}.png"
path.write_bytes(image_data)
return path.absolute()

def _save_html(self, html_data: str) -> Path:
"""Save html data to a file."""
path = self._output_dir / f"{uuid.uuid4().hex}.html"
path.write_text(html_data)
return path.absolute()
Loading