mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Fix/async function and tool execution (#87)
* async run group chat * conversible agent allow async functions to generate reply * test for async execution --------- Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu> Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
@@ -126,6 +126,7 @@ class ConversableAgent(Agent):
|
||||
self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
|
||||
self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
|
||||
self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
|
||||
self.register_reply([Agent, None], ConversableAgent.generate_async_function_call_reply)
|
||||
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
|
||||
|
||||
def register_reply(
|
||||
@@ -661,6 +662,28 @@ class ConversableAgent(Agent):
|
||||
return True, func_return
|
||||
return False, None
|
||||
|
||||
async def generate_async_function_call_reply(
|
||||
self,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
sender: Optional[Agent] = None,
|
||||
config: Optional[Any] = None,
|
||||
):
|
||||
"""Generate a reply using async function call."""
|
||||
if config is None:
|
||||
config = self
|
||||
if messages is None:
|
||||
messages = self._oai_messages[sender]
|
||||
message = messages[-1]
|
||||
if "function_call" in message:
|
||||
func_call = message["function_call"]
|
||||
func_name = func_call.get("name", "")
|
||||
func = self._function_map.get(func_name, None)
|
||||
if func and asyncio.coroutines.iscoroutinefunction(func):
|
||||
_, func_return = await self.a_execute_function(func_call)
|
||||
return True, func_return
|
||||
|
||||
return False, None
|
||||
|
||||
def check_termination_and_human_reply(
|
||||
self,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
@@ -1002,6 +1025,56 @@ class ConversableAgent(Agent):
|
||||
"content": str(content),
|
||||
}
|
||||
|
||||
async def a_execute_function(self, func_call):
|
||||
"""Execute an async function call and return the result.
|
||||
|
||||
Override this function to modify the way async functions are executed.
|
||||
|
||||
Args:
|
||||
func_call: a dictionary extracted from openai message at key "function_call" with keys "name" and "arguments".
|
||||
|
||||
Returns:
|
||||
A tuple of (is_exec_success, result_dict).
|
||||
is_exec_success (boolean): whether the execution is successful.
|
||||
result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function".
|
||||
"""
|
||||
func_name = func_call.get("name", "")
|
||||
func = self._function_map.get(func_name, None)
|
||||
|
||||
is_exec_success = False
|
||||
if func is not None:
|
||||
# Extract arguments from a json-like string and put it into a dict.
|
||||
input_string = self._format_json_str(func_call.get("arguments", "{}"))
|
||||
try:
|
||||
arguments = json.loads(input_string)
|
||||
except json.JSONDecodeError as e:
|
||||
arguments = None
|
||||
content = f"Error: {e}\n You argument should follow json format."
|
||||
|
||||
# Try to execute the function
|
||||
if arguments is not None:
|
||||
print(
|
||||
colored(f"\n>>>>>>>> EXECUTING ASYNC FUNCTION {func_name}...", "magenta"),
|
||||
flush=True,
|
||||
)
|
||||
try:
|
||||
if asyncio.coroutines.iscoroutinefunction(func):
|
||||
content = await func(**arguments)
|
||||
else:
|
||||
# Fallback to sync function if the function is not async
|
||||
content = func(**arguments)
|
||||
is_exec_success = True
|
||||
except Exception as e:
|
||||
content = f"Error: {e}"
|
||||
else:
|
||||
content = f"Error: Function {func_name} not found."
|
||||
|
||||
return is_exec_success, {
|
||||
"name": func_name,
|
||||
"role": "function",
|
||||
"content": str(content),
|
||||
}
|
||||
|
||||
def generate_init_message(self, **context) -> Union[str, Dict]:
|
||||
"""Generate the initial message for the agent.
|
||||
|
||||
|
||||
@@ -130,7 +130,12 @@ class GroupChatManager(ConversableAgent):
|
||||
system_message=system_message,
|
||||
**kwargs,
|
||||
)
|
||||
# Order of register_reply is important.
|
||||
# Allow sync chat if initiated using initiate_chat
|
||||
self.register_reply(Agent, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset)
|
||||
# Allow async chat if initiated using a_initiate_chat
|
||||
self.register_reply(Agent, GroupChatManager.a_run_chat, config=groupchat, reset_config=GroupChat.reset)
|
||||
|
||||
# self._random = random.Random(seed)
|
||||
|
||||
def run_chat(
|
||||
@@ -177,3 +182,48 @@ class GroupChatManager(ConversableAgent):
|
||||
speaker.send(reply, self, request_reply=False)
|
||||
message = self.last_message(speaker)
|
||||
return True, None
|
||||
|
||||
async def a_run_chat(
|
||||
self,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
sender: Optional[Agent] = None,
|
||||
config: Optional[GroupChat] = None,
|
||||
):
|
||||
"""Run a group chat asynchronously."""
|
||||
if messages is None:
|
||||
messages = self._oai_messages[sender]
|
||||
message = messages[-1]
|
||||
speaker = sender
|
||||
groupchat = config
|
||||
for i in range(groupchat.max_round):
|
||||
# set the name to speaker's name if the role is not function
|
||||
if message["role"] != "function":
|
||||
message["name"] = speaker.name
|
||||
groupchat.messages.append(message)
|
||||
# broadcast the message to all agents except the speaker
|
||||
for agent in groupchat.agents:
|
||||
if agent != speaker:
|
||||
await self.a_send(message, agent, request_reply=False, silent=True)
|
||||
if i == groupchat.max_round - 1:
|
||||
# the last round
|
||||
break
|
||||
try:
|
||||
# select the next speaker
|
||||
speaker = groupchat.select_speaker(speaker, self)
|
||||
# let the speaker speak
|
||||
reply = await speaker.a_generate_reply(sender=self)
|
||||
except KeyboardInterrupt:
|
||||
# let the admin agent speak if interrupted
|
||||
if groupchat.admin_name in groupchat.agent_names:
|
||||
# admin agent is one of the participants
|
||||
speaker = groupchat.agent_by_name(groupchat.admin_name)
|
||||
reply = await speaker.a_generate_reply(sender=self)
|
||||
else:
|
||||
# admin agent is not found in the participants
|
||||
raise
|
||||
if reply is None:
|
||||
break
|
||||
# The speaker sends the message without requesting a reply
|
||||
await speaker.a_send(reply, self, request_reply=False)
|
||||
message = self.last_message(speaker)
|
||||
return True, None
|
||||
|
||||
@@ -127,7 +127,68 @@ def test_execute_function():
|
||||
assert user.execute_function(func_call)[1]["content"] == "42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_a_execute_function():
|
||||
from autogen.agentchat import UserProxyAgent
|
||||
import time
|
||||
|
||||
# Create an async function
|
||||
async def add_num(num_to_be_added):
|
||||
given_num = 10
|
||||
time.sleep(1)
|
||||
return num_to_be_added + given_num
|
||||
|
||||
user = UserProxyAgent(name="test", function_map={"add_num": add_num})
|
||||
correct_args = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}
|
||||
|
||||
# Asset coroutine doesn't match.
|
||||
assert user.execute_function(func_call=correct_args)[1]["content"] != "15"
|
||||
# Asset awaited coroutine does match.
|
||||
assert (await user.a_execute_function(func_call=correct_args))[1]["content"] == "15"
|
||||
|
||||
# function name called is wrong or doesn't exist
|
||||
wrong_func_name = {"name": "subtract_num", "arguments": '{ "num_to_be_added": 5 }'}
|
||||
assert "Error: Function" in (await user.a_execute_function(func_call=wrong_func_name))[1]["content"]
|
||||
|
||||
# arguments passed is not in correct json format
|
||||
wrong_json_format = {
|
||||
"name": "add_num",
|
||||
"arguments": '{ "num_to_be_added": 5, given_num: 10 }',
|
||||
} # should be "given_num" with quotes
|
||||
assert (
|
||||
"You argument should follow json format."
|
||||
in (await user.a_execute_function(func_call=wrong_json_format))[1]["content"]
|
||||
)
|
||||
|
||||
# function execution error with wrong arguments passed
|
||||
wrong_args = {"name": "add_num", "arguments": '{ "num_to_be_added": 5, "given_num": 10 }'}
|
||||
assert "Error: " in (await user.a_execute_function(func_call=wrong_args))[1]["content"]
|
||||
|
||||
# 2. test calling a class method
|
||||
class AddNum:
|
||||
def __init__(self, given_num):
|
||||
self.given_num = given_num
|
||||
|
||||
def add(self, num_to_be_added):
|
||||
self.given_num = num_to_be_added + self.given_num
|
||||
return self.given_num
|
||||
|
||||
user = UserProxyAgent(name="test", function_map={"add_num": AddNum(given_num=10).add})
|
||||
func_call = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}
|
||||
assert (await user.a_execute_function(func_call=func_call))[1]["content"] == "15"
|
||||
assert (await user.a_execute_function(func_call=func_call))[1]["content"] == "20"
|
||||
|
||||
# 3. test calling a function with no arguments
|
||||
def get_number():
|
||||
return 42
|
||||
|
||||
user = UserProxyAgent("user", function_map={"get_number": get_number})
|
||||
func_call = {"name": "get_number", "arguments": "{}"}
|
||||
assert (await user.a_execute_function(func_call))[1]["content"] == "42"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_json_extraction()
|
||||
test_execute_function()
|
||||
test_a_execute_function()
|
||||
test_eval_math_responses()
|
||||
|
||||
Reference in New Issue
Block a user