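"""FunctionsBasedAgent: selects next Actions for an AgentService using
OpenAI function-style LLM prompting."""
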
import json
from typing import List

from steamship import Block, MimeTypes, Tag
from steamship.agents.functional.output_parser import FunctionsBasedOutputParser
from steamship.agents.schema import Action, AgentContext, ChatAgent, ChatLLM, FinishAction, Tool
from steamship.data.tags.tag_constants import RoleTag, TagKind, TagValueKey
from steamship.data.tags.tag_utils import get_tag


class FunctionsBasedAgent(ChatAgent):
    """Selects actions for AgentService based on OpenAI Function style LLM Prompting."""

    PROMPT = """You are a helpful AI assistant.

NOTE: Some functions return images, video, and audio files. These multimedia files will be represented in messages as
UUIDs for Steamship Blocks. When responding directly to a user, you SHOULD print the Steamship Blocks for the images,
video, or audio as follows: `Block(UUID for the block)`.

Example response for a request that generated an image:
Here is the image you requested: Block(288A2CA1-4753-4298-9716-53C1E42B726B).

Only use the functions you have been provided with."""

    def __init__(self, tools: List[Tool], llm: ChatLLM, **kwargs):
        super().__init__(
            output_parser=FunctionsBasedOutputParser(tools=tools), llm=llm, tools=tools, **kwargs
        )
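
    # Reuse the system message already present in chat history; otherwise seed
    # the history with the default PROMPT.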
    def _get_or_create_system_message(self, context: AgentContext) -> Block:
        sys_msg = context.chat_history.last_system_message
        if sys_msg:
            return sys_msg
        return context.chat_history.append_system_message(text=self.PROMPT, mime_type=MimeTypes.TXT)
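
    # Assemble the LLM input: system prompt first, then relevant prior
    # conversation (semantic search over history), the most recent messages,
    # the latest user message, and finally any tool calls completed since it.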
    def build_chat_history_for_tool(self, context: AgentContext) -> List[Block]:
        # get system message
        messages: List[Block] = [self._get_or_create_system_message(context)]

        # get prior conversations
        messages_from_memory = []
        if context.chat_history.is_searchable():
            messages_from_memory.extend(
                context.chat_history.search(context.chat_history.last_user_message.text, k=3)
                .wait()
                .to_ranked_blocks()
            )
            # TODO(dougreid): we need a way to threshold message inclusion, especially for small contexts

        # get most recent context
        messages_from_memory.extend(context.chat_history.select_messages(self.message_selector))

        # de-dupe the messages from memory
        ids = [
            context.chat_history.last_user_message.id
        ]  # filter out last user message; it is appended afterwards
        for msg in messages_from_memory:
            if msg.id not in ids:
                messages.append(msg)
                ids.append(msg.id)

        # TODO(dougreid): sort by dates? we SHOULD ensure ordering, given semantic search

        # put the user prompt in the appropriate message location
        # this should happen BEFORE any agent/assistant messages related to tool selection
        messages.append(context.chat_history.last_user_message)

        # get working history (completed actions)
        messages.extend(self._function_calls_since_last_user_message(context))

        return messages
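
    # One reasoning step: ask the LLM (with tools bound) what to do next, and
    # log the selection unless the agent decided to finish.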
    def next_action(self, context: AgentContext) -> Action:
        # Build the Chat History that we'll provide as input to the action
        messages = self.build_chat_history_for_tool(context)

        # Run the default LLM on those messages
        output_blocks = self.llm.chat(messages=messages, tools=self.tools)

        future_action = self.output_parser.parse(output_blocks[0].text, context)
        if not isinstance(future_action, FinishAction):
            # record the LLM's function response in history
            self._record_action_selection(future_action, context)
        return future_action
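
    # Walk history newest-to-oldest, collecting function results and
    # function-selection messages until the most recent user message is hit.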
    def _function_calls_since_last_user_message(self, context: AgentContext) -> List[Block]:
        function_calls = []
        for block in context.chat_history.messages[::-1]:  # is this too inefficient at scale?
            if block.chat_role == RoleTag.USER:
                return list(reversed(function_calls))
            if get_tag(block.tags, kind=TagKind.ROLE, name=RoleTag.FUNCTION):
                function_calls.append(block)
            elif get_tag(block.tags, kind=TagKind.FUNCTION_SELECTION):
                function_calls.append(block)
        # list() keeps the return type honest: reversed() yields an iterator
        return list(reversed(function_calls))
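
    # Serialized below into the OpenAI "function_call" message shape, e.g.
    # (tool name and argument are illustrative only):
    #   {"name": "GenerateImageTool", "arguments": "{\"prompt\": \"a red fox\"}"}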
    def _to_openai_function_selection(self, action: Action) -> str:
        fc = {"name": action.tool}
        args = {}
        for block in action.input:
            for t in block.tags:
                if t.kind == TagKind.FUNCTION_ARG:
                    args[t.name] = block.as_llm_input(exclude_block_wrapper=True)

        fc["arguments"] = json.dumps(args)  # the arguments must be a string value, NOT a dict
        return json.dumps(fc)
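
    # Persist the tool choice as an assistant message so subsequent turns see
    # the pending function call in history.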
    def _record_action_selection(self, action: Action, context: AgentContext):
        tags = [
            Tag(kind=TagKind.ROLE, name=RoleTag.ASSISTANT),
            Tag(kind=TagKind.FUNCTION_SELECTION, name=action.tool),
            Tag(
                kind="request-id",
                name=context.request_id,
                value={TagValueKey.STRING_VALUE: context.request_id},
            ),
        ]
        context.chat_history.file.append_block(
            text=self._to_openai_function_selection(action), tags=tags, mime_type=MimeTypes.TXT
        )
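
    # Persist each block of tool output as a FUNCTION-role message, tagged
    # with the tool name and originating request id.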
    def record_action_run(self, action: Action, context: AgentContext):
        super().record_action_run(action, context)

        tags = [
            Tag(
                kind=TagKind.ROLE,
                name=RoleTag.FUNCTION,
                value={TagValueKey.STRING_VALUE: action.tool},
            ),
            Tag(
                kind="request-id",
                name=context.request_id,
                value={TagValueKey.STRING_VALUE: context.request_id},
            ),
        ]

        # TODO(dougreid): I'm not convinced this is correct for tools that return multiple values.
        # It _feels_ like these should be named and inlined as a single message in history, etc.
        for block in action.output:
            context.chat_history.file.append_block(
                text=block.as_llm_input(exclude_block_wrapper=True),
                tags=tags,
                mime_type=block.mime_type,
            )
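

# --- Usage sketch (not part of the original module) --------------------------
# A minimal, hedged example of driving this agent directly. The import paths
# and the `ChatOpenAI` / `AgentContext.get_or_create` signatures are
# assumptions based on the steamship SDK this file ships with, and may need
# adjusting; `SearchTool` stands in for any Tool implementation.
#
#     from steamship import Steamship
#     from steamship.agents.llms.openai import ChatOpenAI
#     from steamship.agents.schema import AgentContext
#     from steamship.agents.tools.search.search import SearchTool
#
#     client = Steamship()
#     agent = FunctionsBasedAgent(tools=[SearchTool()], llm=ChatOpenAI(client))
#     context = AgentContext.get_or_create(client, context_keys={"id": "demo"})
#     context.chat_history.append_user_message(text="What is the capital of France?")
#     action = agent.next_action(context)  # an Action to run, or a FinishAction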