mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Filter out candidates with the same name but different instructions, … (#925)
* Filter out candidates with the same name but different instructions, file IDs, and function names * polish * improve log * improving log * improve log * Improve function signature (#2) * try to fix ci * try to fix ci --------- Co-authored-by: gagb <gagb@users.noreply.github.com> Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
@@ -45,7 +45,7 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
- tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval,
|
||||
or build your own tools using Function calling. ref https://platform.openai.com/docs/assistants/tools
|
||||
- file_ids: files used by retrieval in run
|
||||
overwrite_instructions (bool): whether to overwrite the instructions of an existing assistant.
|
||||
overwrite_instructions (bool): whether to overwrite the instructions of an existing assistant. This parameter is in effect only when assistant_id is specified in llm_config.
|
||||
kwargs (dict): Additional configuration options for the agent.
|
||||
- verbose (bool): If set to True, enables more detailed output from the assistant thread.
|
||||
- Other kwargs: All other keyword arguments (everything except verbose) are passed directly to ConversableAgent.
|
||||
@@ -59,9 +59,14 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
if openai_assistant_id is None:
|
||||
# try to find assistant by name first
|
||||
candidate_assistants = retrieve_assistants_by_name(self._openai_client, name)
|
||||
if len(candidate_assistants) > 0:
|
||||
# Filter out candidates with the same name but different instructions, file IDs, and function names.
|
||||
candidate_assistants = self.find_matching_assistant(
|
||||
candidate_assistants, instructions, llm_config.get("tools", []), llm_config.get("file_ids", [])
|
||||
)
|
||||
|
||||
if len(candidate_assistants) == 0:
|
||||
logger.warning(f"assistant {name} does not exist, creating a new assistant")
|
||||
logger.warning("No matching assistant found, creating a new assistant")
|
||||
# create a new assistant
|
||||
if instructions is None:
|
||||
logger.warning(
|
||||
@@ -76,11 +81,10 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
file_ids=llm_config.get("file_ids", []),
|
||||
)
|
||||
else:
|
||||
if len(candidate_assistants) > 1:
|
||||
logger.warning(
|
||||
f"Multiple assistants with name {name} found. Using the first assistant in the list. "
|
||||
f"Please specify the assistant ID in llm_config to use a specific assistant."
|
||||
)
|
||||
logger.warning(
|
||||
"Matching assistant found, using the first matching assistant: %s",
|
||||
candidate_assistants[0].__dict__,
|
||||
)
|
||||
self._openai_assistant = candidate_assistants[0]
|
||||
else:
|
||||
# retrieve an existing assistant
|
||||
@@ -368,3 +372,53 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
"""Delete the assistant from OAI assistant API"""
|
||||
logger.warning("Permanently deleting assistant...")
|
||||
self._openai_client.beta.assistants.delete(self.assistant_id)
|
||||
|
||||
def find_matching_assistant(self, candidate_assistants, instructions, tools, file_ids):
    """
    Find the matching assistant from a list of candidate assistants.
    Filter out candidates with the same name but different instructions, file IDs, and function names.
    TODO: implement accurate match based on assistant metadata fields.
    """
    # Precompute the expected configuration once so each candidate is
    # compared against sets (unordered, O(1) membership) rather than lists.
    expected_tool_types = {tool.get("type") for tool in tools}
    expected_function_names = {
        tool.get("function", {}).get("name")
        for tool in tools
        if tool.get("type") not in ("code_interpreter", "retrieval")
    }
    expected_file_ids = set(file_ids)

    matches = []
    for candidate in candidate_assistants:
        # Reject candidates whose instructions differ (only when instructions
        # were requested at all).
        if instructions and instructions != getattr(candidate, "instructions", None):
            logger.warning(
                "instructions not match, skip assistant(%s): %s",
                candidate.id,
                getattr(candidate, "instructions", None),
            )
            continue

        # Summarize the candidate's tools the same way as the expected config.
        candidate_tool_types = {tool.type for tool in candidate.tools}
        candidate_function_names = {
            tool.function.name for tool in candidate.tools if hasattr(tool, "function")
        }
        candidate_file_ids = set(getattr(candidate, "file_ids", []))

        # Tool types and function names must both match exactly.
        if candidate_tool_types != expected_tool_types or candidate_function_names != expected_function_names:
            logger.warning(
                "tools not match, skip assistant(%s): tools %s, functions %s",
                candidate.id,
                candidate_tool_types,
                candidate_function_names,
            )
            continue
        # File IDs are compared as unordered sets.
        if candidate_file_ids != expected_file_ids:
            logger.warning("file_ids not match, skip assistant(%s): %s", candidate.id, candidate_file_ids)
            continue

        # Candidate agrees on instructions, tools, and file IDs.
        matches.append(candidate)

    return matches
|
||||
|
||||
@@ -6,6 +6,15 @@ from typing import List, Optional, Dict, Set, Union
|
||||
import logging
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
|
||||
try:
|
||||
from openai import OpenAI
|
||||
from openai.types.beta.assistant import Assistant
|
||||
|
||||
ERROR = None
|
||||
except ImportError:
|
||||
ERROR = ImportError("Please install openai>=1 to use autogen.OpenAIWrapper.")
|
||||
OpenAI = object
|
||||
Assistant = object
|
||||
|
||||
NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]
|
||||
|
||||
@@ -413,10 +422,12 @@ def config_list_from_dotenv(
|
||||
return config_list
|
||||
|
||||
|
||||
def retrieve_assistants_by_name(client, name) -> str:
|
||||
def retrieve_assistants_by_name(client: OpenAI, name: str) -> List[Assistant]:
|
||||
"""
|
||||
Return the assistants with the given name from OAI assistant API
|
||||
"""
|
||||
if ERROR:
|
||||
raise ERROR
|
||||
assistants = client.beta.assistants.list()
|
||||
candidate_assistants = []
|
||||
for assistant in assistants.data:
|
||||
|
||||
@@ -222,17 +222,46 @@ def test_assistant_retrieval():
|
||||
|
||||
name = "For test_assistant_retrieval"
|
||||
|
||||
function_1_schema = {
|
||||
"name": "call_function_1",
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
"description": "This is a test function 1",
|
||||
}
|
||||
function_2_schema = {
|
||||
"name": "call_function_1",
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
"description": "This is a test function 2",
|
||||
}
|
||||
|
||||
openai_client = OpenAIWrapper(config_list=config_list)._clients[0]
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
file_1 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
|
||||
file_2 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
|
||||
|
||||
all_llm_config = {
|
||||
"tools": [
|
||||
{"type": "function", "function": function_1_schema},
|
||||
{"type": "function", "function": function_2_schema},
|
||||
{"type": "retrieval"},
|
||||
{"type": "code_interpreter"},
|
||||
],
|
||||
"file_ids": [file_1.id, file_2.id],
|
||||
"config_list": config_list,
|
||||
}
|
||||
|
||||
name = "For test_gpt_assistant_chat"
|
||||
|
||||
assistant_first = GPTAssistantAgent(
|
||||
name,
|
||||
instructions="This is a test",
|
||||
llm_config={"config_list": config_list},
|
||||
llm_config=all_llm_config,
|
||||
)
|
||||
candidate_first = retrieve_assistants_by_name(assistant_first.openai_client, name)
|
||||
|
||||
assistant_second = GPTAssistantAgent(
|
||||
name,
|
||||
instructions="This is a test",
|
||||
llm_config={"config_list": config_list},
|
||||
llm_config=all_llm_config,
|
||||
)
|
||||
candidate_second = retrieve_assistants_by_name(assistant_second.openai_client, name)
|
||||
|
||||
@@ -243,7 +272,125 @@ def test_assistant_retrieval():
|
||||
# Not found error is expected because the same assistant can not be deleted twice
|
||||
pass
|
||||
|
||||
openai_client.files.delete(file_1.id)
|
||||
openai_client.files.delete(file_2.id)
|
||||
|
||||
assert candidate_first == candidate_second
|
||||
assert len(candidate_first) == 1
|
||||
|
||||
candidates = retrieve_assistants_by_name(openai_client, name)
|
||||
assert len(candidates) == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform in ["darwin", "win32"] or skip_test,
    reason="do not run on MacOS or windows or dependency is not installed",
)
def test_assistant_mismatch_retrieval():
    """Check that GPTAssistantAgent filters out same-named assistants whose
    instructions, tools, or file IDs do not match the requested configuration.

    Creates a baseline assistant, then three assistants that each differ in
    exactly one dimension (instructions, file IDs, tools), asserting the count
    of assistants registered under the shared name grows by one each time.
    """
    # Two functions share the name "call_function" (only descriptions differ)
    # and one has a different name, to exercise function-name matching.
    function_1_schema = {
        "name": "call_function",
        "parameters": {"type": "object", "properties": {}, "required": []},
        "description": "This is a test function 1",
    }
    function_2_schema = {
        "name": "call_function",
        "parameters": {"type": "object", "properties": {}, "required": []},
        "description": "This is a test function 2",
    }
    function_3_schema = {
        "name": "call_function_other",
        "parameters": {"type": "object", "properties": {}, "required": []},
        "description": "This is a test function 3",
    }

    openai_client = OpenAIWrapper(config_list=config_list)._clients[0]
    current_file_path = os.path.abspath(__file__)
    # Use context managers so the upload handles are closed deterministically
    # instead of leaking until garbage collection.
    with open(current_file_path, "rb") as file_handle:
        file_1 = openai_client.files.create(file=file_handle, purpose="assistants")
    with open(current_file_path, "rb") as file_handle:
        file_2 = openai_client.files.create(file=file_handle, purpose="assistants")

    all_llm_config = {
        "tools": [
            {"type": "function", "function": function_1_schema},
            {"type": "function", "function": function_2_schema},
            {"type": "retrieval"},
            {"type": "code_interpreter"},
        ],
        "file_ids": [file_1.id, file_2.id],
        "config_list": config_list,
    }

    name = "For test_gpt_assistant_chat"

    # Baseline assistant: exactly one assistant exists under this name.
    assistant_first = GPTAssistantAgent(
        name,
        instructions="This is a test",
        llm_config=all_llm_config,
    )
    candidate_first = retrieve_assistants_by_name(assistant_first.openai_client, name)
    assert len(candidate_first) == 1

    # test instructions mismatch
    assistant_instructions_mismatch = GPTAssistantAgent(
        name,
        instructions="This is a test for mismatch instructions",
        llm_config=all_llm_config,
    )
    candidate_instructions_mismatch = retrieve_assistants_by_name(
        assistant_instructions_mismatch.openai_client, name
    )
    assert len(candidate_instructions_mismatch) == 2

    # test mismatched file ids (tool order shuffled on purpose: order must not matter)
    file_ids_mismatch_llm_config = {
        "tools": [
            {"type": "code_interpreter"},
            {"type": "retrieval"},
            {"type": "function", "function": function_2_schema},
            {"type": "function", "function": function_1_schema},
        ],
        "file_ids": [file_2.id],
        "config_list": config_list,
    }
    assistant_file_ids_mismatch = GPTAssistantAgent(
        name,
        instructions="This is a test",
        llm_config=file_ids_mismatch_llm_config,
    )
    candidate_file_ids_mismatch = retrieve_assistants_by_name(assistant_file_ids_mismatch.openai_client, name)
    assert len(candidate_file_ids_mismatch) == 3

    # test tools mismatch (different function name; file id order must not matter)
    tools_mismatch_llm_config = {
        "tools": [
            {"type": "code_interpreter"},
            {"type": "retrieval"},
            {"type": "function", "function": function_3_schema},
        ],
        "file_ids": [file_2.id, file_1.id],
        "config_list": config_list,
    }
    assistant_tools_mismatch = GPTAssistantAgent(
        name,
        instructions="This is a test",
        llm_config=tools_mismatch_llm_config,
    )
    candidate_tools_mismatch = retrieve_assistants_by_name(assistant_tools_mismatch.openai_client, name)
    assert len(candidate_tools_mismatch) == 4

    # Cleanup: remove the uploaded files and every assistant created above.
    openai_client.files.delete(file_1.id)
    openai_client.files.delete(file_2.id)

    assistant_first.delete_assistant()
    assistant_instructions_mismatch.delete_assistant()
    assistant_file_ids_mismatch.delete_assistant()
    assistant_tools_mismatch.delete_assistant()

    candidates = retrieve_assistants_by_name(openai_client, name)
    assert len(candidates) == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -252,3 +399,4 @@ if __name__ == "__main__":
|
||||
test_gpt_assistant_instructions_overwrite()
|
||||
test_gpt_assistant_existing_no_instructions()
|
||||
test_get_assistant_files()
|
||||
test_assistant_mismatch_retrieval()
|
||||
|
||||
Reference in New Issue
Block a user