Skip to content

Commit

Permalink
Filter out candidates with the same name but different instructions, … (
Browse files Browse the repository at this point in the history
microsoft#925)

* Filter out candidates with the same name but different instructions, file IDs, and function names

* polish

* improve log

* improving log

* improve log

* Improve function signature (#2)

* try to fix ci

* try to fix ci

---------

Co-authored-by: gagb <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
3 people authored Dec 27, 2023
1 parent dd516f2 commit d583ad8
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 10 deletions.
68 changes: 61 additions & 7 deletions autogen/agentchat/contrib/gpt_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
- tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval,
or build your own tools using Function calling. ref https://platform.openai.com/docs/assistants/tools
- file_ids: files used by retrieval in run
overwrite_instructions (bool): whether to overwrite the instructions of an existing assistant.
overwrite_instructions (bool): whether to overwrite the instructions of an existing assistant. This parameter is in effect only when assistant_id is specified in llm_config.
kwargs (dict): Additional configuration options for the agent.
- verbose (bool): If set to True, enables more detailed output from the assistant thread.
- Other kwargs: Except verbose, others are passed directly to ConversableAgent.
Expand All @@ -59,9 +59,14 @@ def __init__(
if openai_assistant_id is None:
# try to find assistant by name first
candidate_assistants = retrieve_assistants_by_name(self._openai_client, name)
if len(candidate_assistants) > 0:
# Filter out candidates with the same name but different instructions, file IDs, and function names.
candidate_assistants = self.find_matching_assistant(
candidate_assistants, instructions, llm_config.get("tools", []), llm_config.get("file_ids", [])
)

if len(candidate_assistants) == 0:
logger.warning(f"assistant {name} does not exist, creating a new assistant")
logger.warning("No matching assistant found, creating a new assistant")
# create a new assistant
if instructions is None:
logger.warning(
Expand All @@ -76,11 +81,10 @@ def __init__(
file_ids=llm_config.get("file_ids", []),
)
else:
if len(candidate_assistants) > 1:
logger.warning(
f"Multiple assistants with name {name} found. Using the first assistant in the list. "
f"Please specify the assistant ID in llm_config to use a specific assistant."
)
logger.warning(
"Matching assistant found, using the first matching assistant: %s",
candidate_assistants[0].__dict__,
)
self._openai_assistant = candidate_assistants[0]
else:
# retrieve an existing assistant
Expand Down Expand Up @@ -368,3 +372,53 @@ def delete_assistant(self):
"""Delete the assistant from OAI assistant API"""
logger.warning("Permanently deleting assistant...")
self._openai_client.beta.assistants.delete(self.assistant_id)

def find_matching_assistant(self, candidate_assistants, instructions, tools, file_ids):
"""
Find the matching assistant from a list of candidate assistants.
Filter out candidates with the same name but different instructions, file IDs, and function names.
TODO: implement accurate match based on assistant metadata fields.
"""
matching_assistants = []

# Preprocess the required tools for faster comparison
required_tool_types = set(tool.get("type") for tool in tools)
required_function_names = set(
tool.get("function", {}).get("name")
for tool in tools
if tool.get("type") not in ["code_interpreter", "retrieval"]
)
required_file_ids = set(file_ids) # Convert file_ids to a set for unordered comparison

for assistant in candidate_assistants:
# Check if instructions are similar
if instructions and instructions != getattr(assistant, "instructions", None):
logger.warning(
"instructions not match, skip assistant(%s): %s",
assistant.id,
getattr(assistant, "instructions", None),
)
continue

# Preprocess the assistant's tools
assistant_tool_types = set(tool.type for tool in assistant.tools)
assistant_function_names = set(tool.function.name for tool in assistant.tools if hasattr(tool, "function"))
assistant_file_ids = set(getattr(assistant, "file_ids", [])) # Convert to set for comparison

# Check if the tool types, function names, and file IDs match
if required_tool_types != assistant_tool_types or required_function_names != assistant_function_names:
logger.warning(
"tools not match, skip assistant(%s): tools %s, functions %s",
assistant.id,
assistant_tool_types,
assistant_function_names,
)
continue
if required_file_ids != assistant_file_ids:
logger.warning("file_ids not match, skip assistant(%s): %s", assistant.id, assistant_file_ids)
continue

# Append assistant to matching list if all conditions are met
matching_assistants.append(assistant)

return matching_assistants
13 changes: 12 additions & 1 deletion autogen/oai/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
import logging
from dotenv import find_dotenv, load_dotenv

try:
from openai import OpenAI
from openai.types.beta.assistant import Assistant

ERROR = None
except ImportError:
ERROR = ImportError("Please install openai>=1 to use autogen.OpenAIWrapper.")
OpenAI = object
Assistant = object

NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]

Expand Down Expand Up @@ -413,10 +422,12 @@ def config_list_from_dotenv(
return config_list


def retrieve_assistants_by_name(client, name) -> str:
def retrieve_assistants_by_name(client: OpenAI, name: str) -> List[Assistant]:
"""
Return the assistants with the given name from OAI assistant API
"""
if ERROR:
raise ERROR
assistants = client.beta.assistants.list()
candidate_assistants = []
for assistant in assistants.data:
Expand Down
152 changes: 150 additions & 2 deletions test/agentchat/contrib/test_gpt_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,17 +222,46 @@ def test_assistant_retrieval():

name = "For test_assistant_retrieval"

function_1_schema = {
"name": "call_function_1",
"parameters": {"type": "object", "properties": {}, "required": []},
"description": "This is a test function 1",
}
function_2_schema = {
"name": "call_function_1",
"parameters": {"type": "object", "properties": {}, "required": []},
"description": "This is a test function 2",
}

openai_client = OpenAIWrapper(config_list=config_list)._clients[0]
current_file_path = os.path.abspath(__file__)
file_1 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
file_2 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")

all_llm_config = {
"tools": [
{"type": "function", "function": function_1_schema},
{"type": "function", "function": function_2_schema},
{"type": "retrieval"},
{"type": "code_interpreter"},
],
"file_ids": [file_1.id, file_2.id],
"config_list": config_list,
}

name = "For test_gpt_assistant_chat"

assistant_first = GPTAssistantAgent(
name,
instructions="This is a test",
llm_config={"config_list": config_list},
llm_config=all_llm_config,
)
candidate_first = retrieve_assistants_by_name(assistant_first.openai_client, name)

assistant_second = GPTAssistantAgent(
name,
instructions="This is a test",
llm_config={"config_list": config_list},
llm_config=all_llm_config,
)
candidate_second = retrieve_assistants_by_name(assistant_second.openai_client, name)

Expand All @@ -243,7 +272,125 @@ def test_assistant_retrieval():
# Not found error is expected because the same assistant can not be deleted twice
pass

openai_client.files.delete(file_1.id)
openai_client.files.delete(file_2.id)

assert candidate_first == candidate_second
assert len(candidate_first) == 1

candidates = retrieve_assistants_by_name(openai_client, name)
assert len(candidates) == 0


@pytest.mark.skipif(
sys.platform in ["darwin", "win32"] or skip_test,
reason="do not run on MacOS or windows or dependency is not installed",
)
def test_assistant_mismatch_retrieval():
"""Test function to check if the GPTAssistantAgent can filter out the mismatch assistant"""

name = "For test_assistant_retrieval"

function_1_schema = {
"name": "call_function",
"parameters": {"type": "object", "properties": {}, "required": []},
"description": "This is a test function 1",
}
function_2_schema = {
"name": "call_function",
"parameters": {"type": "object", "properties": {}, "required": []},
"description": "This is a test function 2",
}
function_3_schema = {
"name": "call_function_other",
"parameters": {"type": "object", "properties": {}, "required": []},
"description": "This is a test function 3",
}

openai_client = OpenAIWrapper(config_list=config_list)._clients[0]
current_file_path = os.path.abspath(__file__)
file_1 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
file_2 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")

all_llm_config = {
"tools": [
{"type": "function", "function": function_1_schema},
{"type": "function", "function": function_2_schema},
{"type": "retrieval"},
{"type": "code_interpreter"},
],
"file_ids": [file_1.id, file_2.id],
"config_list": config_list,
}

name = "For test_gpt_assistant_chat"

assistant_first = GPTAssistantAgent(
name,
instructions="This is a test",
llm_config=all_llm_config,
)
candidate_first = retrieve_assistants_by_name(assistant_first.openai_client, name)
assert len(candidate_first) == 1

# test instructions mismatch
assistant_instructions_mistaching = GPTAssistantAgent(
name,
instructions="This is a test for mismatch instructions",
llm_config=all_llm_config,
)
candidate_instructions_mistaching = retrieve_assistants_by_name(
assistant_instructions_mistaching.openai_client, name
)
assert len(candidate_instructions_mistaching) == 2

# test mismatch fild ids
file_ids_mismatch_llm_config = {
"tools": [
{"type": "code_interpreter"},
{"type": "retrieval"},
{"type": "function", "function": function_2_schema},
{"type": "function", "function": function_1_schema},
],
"file_ids": [file_2.id],
"config_list": config_list,
}
assistant_file_ids_mismatch = GPTAssistantAgent(
name,
instructions="This is a test",
llm_config=file_ids_mismatch_llm_config,
)
candidate_file_ids_mismatch = retrieve_assistants_by_name(assistant_file_ids_mismatch.openai_client, name)
assert len(candidate_file_ids_mismatch) == 3

# test tools mismatch
tools_mismatch_llm_config = {
"tools": [
{"type": "code_interpreter"},
{"type": "retrieval"},
{"type": "function", "function": function_3_schema},
],
"file_ids": [file_2.id, file_1.id],
"config_list": config_list,
}
assistant_tools_mistaching = GPTAssistantAgent(
name,
instructions="This is a test",
llm_config=tools_mismatch_llm_config,
)
candidate_tools_mismatch = retrieve_assistants_by_name(assistant_tools_mistaching.openai_client, name)
assert len(candidate_tools_mismatch) == 4

openai_client.files.delete(file_1.id)
openai_client.files.delete(file_2.id)

assistant_first.delete_assistant()
assistant_instructions_mistaching.delete_assistant()
assistant_file_ids_mismatch.delete_assistant()
assistant_tools_mistaching.delete_assistant()

candidates = retrieve_assistants_by_name(openai_client, name)
assert len(candidates) == 0


if __name__ == "__main__":
Expand All @@ -252,3 +399,4 @@ def test_assistant_retrieval():
test_gpt_assistant_instructions_overwrite()
test_gpt_assistant_existing_no_instructions()
test_get_assistant_files()
test_assistant_mismatch_retrieval()

0 comments on commit d583ad8

Please sign in to comment.