Skip to content

Commit

Permalink
Implement local jupyter notebook execution support
Browse files Browse the repository at this point in the history
  • Loading branch information
Leon0402 committed Dec 24, 2024
1 parent b15551c commit f53e14a
Show file tree
Hide file tree
Showing 9 changed files with 999 additions and 349 deletions.
4 changes: 3 additions & 1 deletion python/packages/autogen-ext/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ video-surfer = [
"ffmpeg-python",
"openai-whisper",
]

grpc = [
"grpcio~=1.62.0", # TODO: update this once we have a stable version.
]
jupyter-executor = [
"jupyter-kernel-gateway>=3.0.1",
]

[tool.hatch.build.targets.wheel]
packages = ["src/autogen_ext"]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from ._jupyter_client import JupyterClient
from ._jupyter_code_executor import JupyterCodeExecutor, JupyterCodeResult
from ._jupyter_connectable import JupyterConnectable
from ._jupyter_connection_info import JupyterConnectionInfo
from ._local_jupyter_server import LocalJupyterServer

# Names exported by a star-import of this module; each one is imported above.
__all__ = [
    "JupyterConnectable",
    "JupyterConnectionInfo",
    "JupyterClient",
    "LocalJupyterServer",
    "JupyterCodeExecutor",
    "JupyterCodeResult",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
from __future__ import annotations

import sys
from dataclasses import dataclass
from types import TracebackType
from typing import Any, AsyncGenerator, cast

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

import datetime
import json
import uuid

import requests
from requests.adapters import HTTPAdapter, Retry
from websockets.asyncio.client import ClientConnection, connect

from ._jupyter_connection_info import JupyterConnectionInfo


class JupyterClient:
    """(Experimental) A client for communicating with a Jupyter gateway server.

    REST requests are issued through a shared ``requests.Session`` configured
    with retries; kernel channels are opened as websockets through
    :meth:`get_kernel_client`.
    """

    def __init__(self, connection_info: JupyterConnectionInfo):
        """Create a client for the server described by ``connection_info``.

        Args:
            connection_info (JupyterConnectionInfo): Connection information
        """
        self._connection_info = connection_info
        self._session = requests.Session()
        retries = Retry(total=5, backoff_factor=0.1)
        # Mount the retry policy for both schemes; otherwise servers reached
        # over HTTPS would silently get no retries at all.
        for prefix in ("http://", "https://"):
            self._session.mount(prefix, HTTPAdapter(max_retries=retries))

    def _get_headers(self) -> dict[str, str]:
        """Return the authorization headers, or an empty dict when no token is set."""
        if self._connection_info.token is None:
            return {}
        return {"Authorization": f"token {self._connection_info.token}"}

    def _get_cookies(self) -> str:
        """Return the session cookies formatted as a ``Cookie`` header value."""
        cookies = self._session.cookies.get_dict()
        return "; ".join(f"{name}={value}" for name, value in cookies.items())

    def _get_api_base_url(self) -> str:
        """Return the base URL (scheme, host and optional port) for REST calls."""
        protocol = "https" if self._connection_info.use_https else "http"
        port = f":{self._connection_info.port}" if self._connection_info.port else ""
        return f"{protocol}://{self._connection_info.host}{port}"

    def _get_ws_base_url(self) -> str:
        """Return the base URL (scheme, host and optional port) for websockets."""
        # Mirror the REST scheme: a TLS-enabled server must be reached via wss.
        protocol = "wss" if self._connection_info.use_https else "ws"
        port = f":{self._connection_info.port}" if self._connection_info.port else ""
        return f"{protocol}://{self._connection_info.host}{port}"

    def list_kernel_specs(self) -> dict[str, dict[str, str]]:
        """Return the decoded JSON body of ``GET /api/kernelspecs``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers())
        response.raise_for_status()
        return cast(dict[str, dict[str, str]], response.json())

    def list_kernels(self) -> list[dict[str, str]]:
        """Return the decoded JSON body of ``GET /api/kernels``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        response = self._session.get(f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers())
        response.raise_for_status()
        return cast(list[dict[str, str]], response.json())

    def start_kernel(self, kernel_spec_name: str) -> str:
        """Start a new kernel.

        Args:
            kernel_spec_name (str): Name of the kernel spec to start

        Returns:
            str: ID of the started kernel

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        response = self._session.post(
            f"{self._get_api_base_url()}/api/kernels",
            headers=self._get_headers(),
            json={"name": kernel_spec_name},
        )
        # Fail with the HTTP error instead of an opaque KeyError on "id".
        response.raise_for_status()
        return cast(str, response.json()["id"])

    def delete_kernel(self, kernel_id: str) -> None:
        """Delete the kernel with the given ID.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        response = self._session.delete(
            f"{self._get_api_base_url()}/api/kernels/{kernel_id}", headers=self._get_headers()
        )
        response.raise_for_status()

    def restart_kernel(self, kernel_id: str) -> None:
        """Restart the kernel with the given ID.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        response = self._session.post(
            f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart", headers=self._get_headers()
        )
        response.raise_for_status()

    async def get_kernel_client(self, kernel_id: str) -> JupyterKernelClient:
        """Open a websocket to the kernel's channels endpoint and wrap it in a client.

        Args:
            kernel_id (str): ID of the kernel to connect to

        Returns:
            JupyterKernelClient: Client bound to the newly opened websocket
        """
        ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels"
        headers = self._get_headers()
        cookies = self._get_cookies()
        if cookies:
            # Only forward cookies that exist; avoid sending an empty Cookie header.
            headers["Cookie"] = cookies
        websocket = await connect(ws_url, additional_headers=headers)
        return JupyterKernelClient(websocket)


class JupyterKernelClient:
"""A client for communicating with a Jupyter kernel."""

@dataclass
class ExecutionResult:
@dataclass
class DataItem:
mime_type: str
data: str

is_ok: bool
output: str
data_items: list[DataItem]

def __init__(self, websocket: ClientConnection):
self._session_id: str = uuid.uuid4().hex
self._websocket = websocket

async def _send_message(self, *, content: dict[str, Any], channel: str, message_type: str) -> str:
timestamp = datetime.datetime.now().isoformat()
message_id = uuid.uuid4().hex
message = {
"header": {
"username": "autogen",
"version": "5.0",
"session": self._session_id,
"msg_id": message_id,
"msg_type": message_type,
"date": timestamp,
},
"parent_header": {},
"channel": channel,
"content": content,
"metadata": {},
"buffers": {},
}
await self._websocket.send(json.dumps(message))
return message_id

async def wait_for_ready(self) -> None:
message_id = await self._send_message(content={}, channel="shell", message_type="kernel_info_request")

async for message in self._receive_message(message_id):
if message["msg_type"] == "kernel_info_reply":
break

async def execute(self, code: str) -> ExecutionResult:
message_id = await self._send_message(
content={
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
channel="shell",
message_type="execute_request",
)

text_output: list[str] = []
data_output: list[JupyterKernelClient.ExecutionResult.DataItem] = []

async for message in self._receive_message(message_id):
content = message["content"]
match message["msg_type"]:
case "execute_result" | "display_data":
for data_type, data in content["data"].items():
match data_type:
case "text/plain":
text_output.append(data)
case data if data.startswith("image/") or data == "text/html":
data_output.append(self.ExecutionResult.DataItem(mime_type=data_type, data=data))
case _:
text_output.append(json.dumps(data))
case "stream":
text_output.append(content["text"])
case "error":
return JupyterKernelClient.ExecutionResult(
is_ok=False,
output=f"ERROR: {content['ename']}: {content['evalue']}\n{content['traceback']}",
data_items=[],
)
case _:
pass

if message["msg_type"] == "status" and content["execution_state"] == "idle":
break

return JupyterKernelClient.ExecutionResult(
is_ok=True, output="\n".join([output for output in text_output]), data_items=data_output
)

async def __aenter__(self) -> Self:
return self

async def __aexit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
await self._websocket.close()

async def _receive_message(self, message_id: str) -> AsyncGenerator[dict[str, Any]]:
async for data in self._websocket:
if isinstance(data, bytes):
data = data.decode("utf-8")
message = cast(dict[str, Any], json.loads(data))
if message.get("parent_header", {}).get("msg_id") == message_id:
yield message
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import asyncio
import base64
import json
import sys
import uuid
from dataclasses import dataclass
from pathlib import Path
from types import TracebackType

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from autogen_core import CancellationToken
from autogen_core.code_executor import CodeBlock, CodeExecutor, CodeResult

from .._common import silence_pip
from ._jupyter_connectable import JupyterConnectable


@dataclass
class JupyterCodeResult(CodeResult):
    """A code result class for Jupyter code executor."""

    # Paths of files written while executing the code blocks; the executor in
    # this file fills it with the saved image and HTML outputs.
    output_files: list[Path]


class JupyterCodeExecutor(CodeExecutor):
def __init__(
self,
server: JupyterConnectable,
kernel_name: str = "python3",
timeout: int = 60,
output_dir: Path = Path("."),
):
"""A code executor class that executes code statefully using
a Jupyter server supplied to this class.
Each execution is stateful and can access variables created from previous
executions in the same session.
Args:
server (JupyterConnectable): The Jupyter server to use.
kernel_name (str): The kernel name to use. Make sure it is installed.
By default, it is "python3".
timeout (int): The timeout for code execution, by default 60.
output_dir (Path): The directory to save output files, by default ".".
"""
if timeout < 1:
raise ValueError("Timeout must be greater than or equal to 1.")

self._jupyter_client = server.get_client()
self._kernel_name = kernel_name
self._timeout = timeout
self._output_dir = output_dir
self.start()

async def execute_code_blocks(
self, code_blocks: list[CodeBlock], cancellation_token: CancellationToken
) -> JupyterCodeResult:
"""Execute code blocks and return the result.
Args:
code_blocks (list[CodeBlock]): The code blocks to execute.
Returns:
JupyterCodeResult: The result of the code execution.
Raises:
asyncio.TimeoutError: Code execution timeouts
asyncio.CancelledError: CancellationToken evoked during execution
"""
async with await self._jupyter_client.get_kernel_client(self._kernel_id) as kernel_client:
wait_for_ready_task = asyncio.create_task(kernel_client.wait_for_ready())
cancellation_token.link_future(wait_for_ready_task)
await asyncio.wait_for(wait_for_ready_task, timeout=self._timeout)

outputs: list[str] = []
output_files: list[Path] = []
for code_block in code_blocks:
code = silence_pip(code_block.code, code_block.language)
execute_task = asyncio.create_task(kernel_client.execute(code))
cancellation_token.link_future(execute_task)
result = await asyncio.wait_for(execute_task, timeout=self._timeout)

if result.is_ok:
outputs.append(result.output)
for data in result.data_items:
match data.mime_type:
case "image/png":
path = self._save_image(data.data)
outputs.append(f"Image data saved to {path}")
output_files.append(path)
case "text/html":
path = self._save_html(data.data)
outputs.append(f"HTML data saved to {path}")
output_files.append(path)
case _:
outputs.append(json.dumps(data.data))
else:
return JupyterCodeResult(exit_code=1, output=f"ERROR: {result.output}", output_files=[])

return JupyterCodeResult(
exit_code=0, output="\n".join([output for output in outputs]), output_files=output_files
)

async def restart(self) -> None:
"""Restart the code executor."""
self._jupyter_client.restart_kernel(self._kernel_id)
self._jupyter_kernel_client = self._jupyter_client.get_kernel_client(self._kernel_id)

def start(self) -> None:
"""Start the kernel."""
available_kernels = self._jupyter_client.list_kernel_specs()
if self._kernel_name not in available_kernels["kernelspecs"]:
raise ValueError(f"Kernel {self._kernel_name} is not installed.")

self._kernel_id = self._jupyter_client.start_kernel(self._kernel_name)

def stop(self) -> None:
"""Stop the kernel."""
if self._kernel_id is not None:
self._jupyter_client.delete_kernel(self._kernel_id)
self._kernel_id = None

def __enter__(self) -> Self:
return self

def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
self.stop()

def _save_image(self, image_data_base64: str) -> Path:
"""Save image data to a file."""
image_data = base64.b64decode(image_data_base64)
path = self._output_dir / f"{uuid.uuid4().hex}.png"
path.write_bytes(image_data)
return path.absolute()

def _save_html(self, html_data: str) -> Path:
"""Save html data to a file."""
path = self._output_dir / f"{uuid.uuid4().hex}.html"
path.write_text(html_data)
return path.absolute()
Loading

0 comments on commit f53e14a

Please sign in to comment.