Skip to content

Commit

Permalink
feat: initial tool
Browse files Browse the repository at this point in the history
  • Loading branch information
plutoless committed Dec 30, 2024
1 parent b5bdedc commit 4e6aae7
Show file tree
Hide file tree
Showing 16 changed files with 246 additions and 4 deletions.
3 changes: 3 additions & 0 deletions agents/ten_packages/extension/coze_python_async/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ async def on_call_chat_completion(
) -> any:
raise RuntimeError("Not implemented")

async def on_generate_image(self, async_ten_env, prompt) -> str:
    """Image generation is not implemented for this extension."""
    raise RuntimeError("Not implemented")

async def on_data_chat_completion(
self, ten_env: AsyncTenEnv, **kargs: LLMDataCompletionArgs
) -> None:
Expand Down
3 changes: 3 additions & 0 deletions agents/ten_packages/extension/dify_python/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ async def on_video_frame(

async def on_call_chat_completion(self, async_ten_env, **kargs):
raise NotImplementedError

async def on_generate_image(self, async_ten_env, prompt) -> str:
    """Image generation is not supported by this extension.

    Fix: the original did `return NotImplementedError`, which silently
    returns the exception *class* to the caller instead of signalling an
    error. Raise it instead, matching the sibling stubs in this file.
    """
    raise NotImplementedError

async def on_tools_update(self, async_ten_env, tool):
raise NotImplementedError
Expand Down
3 changes: 3 additions & 0 deletions agents/ten_packages/extension/gemini_v2v_python/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,5 +739,8 @@ async def _update_usage(self, usage: dict) -> None:
async def on_call_chat_completion(self, async_ten_env, **kargs):
    # Call-style chat completion is not supported by this v2v extension.
    raise NotImplementedError

async def on_generate_image(self, async_ten_env, prompt) -> str:
    """Image generation is not supported by this extension.

    Fix: the original did `return NotImplementedError`, which returns the
    exception *class* instead of raising it — callers would receive a class
    object where a URL string is expected. Raise, like the sibling stubs.
    """
    raise NotImplementedError

async def on_data_chat_completion(self, async_ten_env, **kargs):
    # Data-style chat completion is not supported by this v2v extension.
    raise NotImplementedError
3 changes: 3 additions & 0 deletions agents/ten_packages/extension/glue_python_async/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ async def on_call_chat_completion(
) -> any:
raise RuntimeError("Not implemented")

async def on_generate_image(self, async_ten_env, prompt) -> str:
    """Image generation is not implemented for this extension.

    Fix: the original did `return RuntimeError("Not implemented")`, which
    hands the caller an exception *instance* instead of raising it. Raise,
    matching `on_call_chat_completion` just above in this file.
    """
    raise RuntimeError("Not implemented")

async def on_data_chat_completion(
self, ten_env: AsyncTenEnv, **kargs: LLMDataCompletionArgs
) -> None:
Expand Down
29 changes: 29 additions & 0 deletions agents/ten_packages/extension/image_generate_tool/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# image_generate_tool

<!-- brief introduction for the extension -->

## Features

<!-- main features introduction -->

- xxx feature

## API

Refer to the `api` definition in [manifest.json](manifest.json) and default values in [property.json](property.json).

<!-- Additional API.md can be referred to if extra introduction needed -->

## Development

### Build

<!-- build dependencies and steps -->

### Unit test

<!-- how to do unit test for the extension -->

## Misc

<!-- others if applicable -->
6 changes: 6 additions & 0 deletions agents/ten_packages/extension/image_generate_tool/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
from . import addon
20 changes: 20 additions & 0 deletions agents/ten_packages/extension/image_generate_tool/addon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
from ten import (
Addon,
register_addon_as_extension,
TenEnv,
)
from .extension import ImageGenerateToolExtension


@register_addon_as_extension("image_generate_tool")
class ImageGenerateToolExtensionAddon(Addon):
    """Registers the image_generate_tool extension with the TEN runtime."""

    def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None:
        """Instantiate the extension and hand it back to the runtime."""
        ten_env.log_info("on_create_instance")
        extension = ImageGenerateToolExtension(name)
        ten_env.on_create_instance_done(extension, context)
48 changes: 48 additions & 0 deletions agents/ten_packages/extension/image_generate_tool/extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for more information.
#
from ten import (
TenEnv,
AsyncTenEnv,
)
from ten_ai_base import (
AsyncLLMToolBaseExtension, LLMToolMetadata, LLMToolResult, BaseConfig
)
from dataclasses import dataclass


@dataclass
class ImageGenerateToolConfig(BaseConfig):
    """Configuration for the image_generate_tool extension (no fields yet)."""
    # TODO: add extra config fields here
    pass


class ImageGenerateToolExtension(AsyncLLMToolBaseExtension):
    """Skeleton tool extension intended to expose an image-generation tool.

    All tool behavior is still TODO: it registers no tool metadata and
    `run_tool` only logs.
    """

    def __init__(self, name: str):
        super().__init__(name)
        # Populated from runtime properties in on_start.
        self.config = None

    async def on_start(self, ten_env: AsyncTenEnv) -> None:
        """Load configuration and (eventually) start resources."""
        await super().on_start(ten_env)

        # initialize configuration
        self.config = await ImageGenerateToolConfig.create_async(ten_env=ten_env)
        ten_env.log_info(f"config: {self.config}")

        # Implement this method to construct and start your resources.
        ten_env.log_debug("TODO: on_start")

    async def on_stop(self, ten_env: AsyncTenEnv) -> None:
        """Tear down resources before the extension stops."""
        await super().on_stop(ten_env)

        # Implement this method to stop and destruct your resources.
        ten_env.log_debug("TODO: on_stop")

    def get_tool_metadata(self, ten_env: TenEnv) -> list[LLMToolMetadata]:
        """Describe the tools this extension offers (none yet)."""
        ten_env.log_debug("TODO: get_tool_metadata")
        return []

    async def run_tool(self, ten_env: AsyncTenEnv, name: str, args: dict) -> LLMToolResult | None:
        """Execute the named tool; currently only logs and returns None."""
        ten_env.log_debug(f"TODO: run_tool {name} {args}")
88 changes: 88 additions & 0 deletions agents/ten_packages/extension/image_generate_tool/manifest.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
{
"type": "extension",
"name": "image_generate_tool",
"version": "0.1.0",
"dependencies": [
{
"type": "system",
"name": "ten_runtime_python",
"version": "0.6"
}
],
"package": {
"include": [
"manifest.json",
"property.json",
"requirements.txt",
"**.tent",
"**.py",
"README.md"
]
},
"api": {
"property": {},
"cmd_in": [
{
"name": "tool_call",
"property": {
"name": {
"type": "string"
},
"arguments": {
"type": "string"
}
},
"required": [
"name"
],
"result": {
"property": {
"tool_result": {
"type": "string"
}
},
"required": [
"tool_result"
]
}
}
],
"cmd_out": [
{
"name": "tool_register",
"property": {
"tool": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"description": {
"type": "string"
},
"parameters": {
"type": "array",
"items": {
"type": "object",
"properties": {}
}
}
},
"required": [
"name",
"description",
"parameters"
]
}
},
"result": {
"property": {
"response": {
"type": "string"
}
}
}
}
]
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Empty file.
10 changes: 8 additions & 2 deletions agents/ten_packages/extension/openai_chatgpt_python/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)

from .helper import parse_sentences
from .openai import OpenAIChatGPT, OpenAIChatGPTConfig
from .openai import OpenAIChatGPT, OpenAIChatGPTConfig, OpenAIImageConfig
from ten import (
Cmd,
StatusCode,
Expand All @@ -53,6 +53,7 @@ def __init__(self, name: str):
self.memory = []
self.memory_cache = []
self.config = None
self.image_config = None
self.client = None
self.sentence_fragment = ""
self.tool_task_future = None
Expand All @@ -67,6 +68,7 @@ async def on_start(self, async_ten_env: AsyncTenEnv) -> None:
await super().on_start(async_ten_env)

self.config = await OpenAIChatGPTConfig.create_async(ten_env=async_ten_env)
self.image_config = await OpenAIImageConfig.create_async(ten_env=async_ten_env)

# Mandatory properties
if not self.config.api_key:
Expand All @@ -75,7 +77,7 @@ async def on_start(self, async_ten_env: AsyncTenEnv) -> None:

# Create instance
try:
self.client = OpenAIChatGPT(async_ten_env, self.config)
self.client = OpenAIChatGPT(async_ten_env, self.config, self.image_config)
async_ten_env.log_info(
f"initialized with max_tokens: {self.config.max_tokens}, model: {self.config.model}, vendor: {self.config.vendor}"
)
Expand Down Expand Up @@ -147,6 +149,10 @@ async def on_tools_update(
) -> None:
return await super().on_tools_update(async_ten_env, tool)

async def on_generate_image(self, async_ten_env, prompt) -> str:
    """Generate an image for *prompt* via the OpenAI client; return its URL."""
    return await self.client.generate_image(prompt)

async def on_call_chat_completion(
self, async_ten_env: AsyncTenEnv, **kargs: LLMCallCompletionArgs
) -> any:
Expand Down
24 changes: 23 additions & 1 deletion agents/ten_packages/extension/openai_chatgpt_python/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,20 @@ class OpenAIChatGPTConfig(BaseConfig):
azure_endpoint: str = ""
azure_api_version: str = ""

@dataclass
class OpenAIImageConfig(BaseConfig):
    """Configuration for OpenAI image generation (Images API).

    Fix: the previous default `size = "512x512"` is not accepted for
    dall-e-3 — the API only allows 1024x1024, 1792x1024 and 1024x1792 for
    that model, so every call with defaults would fail. Default to the
    smallest supported size instead.
    """
    model: str = "dall-e-3"
    size: str = "1024x1024"  # was "512x512", which dall-e-3 rejects
    quality: str = "standard"
    n: int = 1  # dall-e-3 supports only n == 1


class OpenAIChatGPT:
client = None

def __init__(self, ten_env: AsyncTenEnv, config: OpenAIChatGPTConfig):
def __init__(self, ten_env: AsyncTenEnv, config: OpenAIChatGPTConfig, image_config: OpenAIImageConfig):
self.config = config
self.image_config = image_config
ten_env.log_info(f"OpenAIChatGPT initialized with config: {config.api_key}")
if self.config.vendor == "azure":
self.client = AsyncAzureOpenAI(
Expand Down Expand Up @@ -173,3 +181,17 @@ async def get_chat_completions_stream(self, messages, tools=None, listener=None)
# Emit content finished event after the loop completes
if listener:
listener.emit("content_finished", full_content)


async def generate_image(self, prompt: str):
    """Generate an image from *prompt* with the OpenAI Images API.

    Uses model/size/quality/n from self.image_config and returns the URL of
    the first generated image.

    Fix: `image_config.n` was declared but never passed to the API call;
    forward it so the config field actually takes effect (default 1 keeps
    behavior identical).

    Raises:
        RuntimeError: wrapping any API/client failure.
    """
    try:
        response = await self.client.images.generate(
            prompt=prompt,
            model=self.image_config.model,
            size=self.image_config.size,
            quality=self.image_config.quality,
            n=self.image_config.n,
        )
    except Exception as e:
        raise RuntimeError(f"GenerateImage failed, err: {e}") from e

    return response.data[0].url
3 changes: 3 additions & 0 deletions agents/ten_packages/extension/openai_v2v_python/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,5 +830,8 @@ async def _update_usage(self, usage: dict) -> None:
async def on_call_chat_completion(self, async_ten_env, **kargs):
    # Call-style chat completion is not supported by this v2v extension.
    raise NotImplementedError

async def on_generate_image(self, async_ten_env, prompt) -> str:
    """Image generation is not supported by this extension.

    Fix: the original did `return NotImplementedError`, which hands the
    caller the exception *class* instead of signalling an error. Raise it,
    matching the sibling stubs in this file.
    """
    raise NotImplementedError

async def on_data_chat_completion(self, async_ten_env, **kargs):
    # Data-style chat completion is not supported by this v2v extension.
    raise NotImplementedError
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
CMD_PROPERTY_TOOL = "tool"
CMD_PROPERTY_RESULT = "tool_result"
CMD_CHAT_COMPLETION_CALL = "chat_completion_call"
CMD_GENERATE_IMAGE_CALL = "generate_image_call"
CMD_IN_FLUSH = "flush"
CMD_OUT_FLUSH = "flush"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async def on_cmd(self, async_ten_env: AsyncTenEnv, cmd: Cmd) -> None:
self.available_tools.append(tool_metadata)
await self.on_tools_update(async_ten_env, tool_metadata)
await async_ten_env.return_result(CmdResult.create(StatusCode.OK), cmd)
except Exception as err:
except Exception:
async_ten_env.log_warn(f"on_cmd failed: {traceback.format_exc()}")
await async_ten_env.return_result(
CmdResult.create(StatusCode.ERROR), cmd
Expand Down Expand Up @@ -147,6 +147,12 @@ async def on_data_chat_completion(
Note that this method is stream-based, and it should consider supporting local context caching.
"""

@abstractmethod
async def on_generate_image(
    self, async_ten_env: AsyncTenEnv, prompt: str
) -> str:
    """Called when an image generation is requested. Implement this method to process the image generation.

    Args:
        async_ten_env: async TEN environment for logging / command exchange.
        prompt: text description of the image to generate.

    Returns:
        A string result — concrete implementations typically return the
        generated image's URL (see the `-> str` annotation).
    """

@abstractmethod
async def on_tools_update(
self, async_ten_env: AsyncTenEnv, tool: LLMToolMetadata
Expand Down

0 comments on commit 4e6aae7

Please sign in to comment.