mirror of
https://github.com/microsoft/autogen.git
synced 2026-01-26 06:47:56 -05:00
function call filter in group chat (#294)
* function call filter in group chat * find agents with function_map
This commit is contained in:
@@ -1017,3 +1017,12 @@ class ConversableAgent(Agent):
|
||||
function_map: a dictionary mapping function names to functions.
|
||||
"""
|
||||
self._function_map.update(function_map)
|
||||
|
||||
def can_execute_function(self, name: str) -> bool:
|
||||
"""Whether the agent can execute the function."""
|
||||
return name in self._function_map
|
||||
|
||||
@property
|
||||
def function_map(self) -> Dict[str, Callable]:
|
||||
"""Return the function map."""
|
||||
return self._function_map
|
||||
|
||||
@@ -10,12 +10,23 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class GroupChat:
|
||||
"""A group chat class that contains a list of agents and the maximum number of rounds."""
|
||||
"""A group chat class that contains the following data fields:
|
||||
- agents: a list of participating agents.
|
||||
- messages: a list of messages in the group chat.
|
||||
- max_round: the maximum number of rounds.
|
||||
- admin_name: the name of the admin agent if there is one. Default is "Admin".
|
||||
KeyBoardInterrupt will make the admin agent take over.
|
||||
- func_call_filter: whether to enforce function call filter. Default is True.
|
||||
When set to True and when a message is a function call suggestion,
|
||||
the next speaker will be chosen from an agent which contains the corresponding function name
|
||||
in its `function_map`.
|
||||
"""
|
||||
|
||||
agents: List[Agent]
|
||||
messages: List[Dict]
|
||||
max_round: int = 10
|
||||
admin_name: str = "Admin" # the name of the admin agent
|
||||
admin_name: str = "Admin"
|
||||
func_call_filter: bool = True
|
||||
|
||||
@property
|
||||
def agent_names(self) -> List[str]:
|
||||
@@ -30,45 +41,69 @@ class GroupChat:
|
||||
"""Find the next speaker based on the message."""
|
||||
return self.agents[self.agent_names.index(name)]
|
||||
|
||||
def next_agent(self, agent: Agent) -> Agent:
|
||||
def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
|
||||
"""Return the next agent in the list."""
|
||||
return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)]
|
||||
if agents == self.agents:
|
||||
return agents[(self.agent_names.index(agent.name) + 1) % len(agents)]
|
||||
else:
|
||||
offset = self.agent_names.index(agent.name) + 1
|
||||
for i in range(len(self.agents)):
|
||||
if self.agents[(offset + i) % len(self.agents)] in agents:
|
||||
return self.agents[(offset + i) % len(self.agents)]
|
||||
|
||||
def select_speaker_msg(self):
|
||||
def select_speaker_msg(self, agents: List[Agent]):
|
||||
"""Return the message for selecting the next speaker."""
|
||||
return f"""You are in a role play game. The following roles are available:
|
||||
{self._participant_roles()}.
|
||||
|
||||
Read the following conversation.
|
||||
Then select the next role from {self.agent_names} to play. Only return the role."""
|
||||
Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""
|
||||
|
||||
def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
|
||||
"""Select the next speaker."""
|
||||
selector.update_system_message(self.select_speaker_msg())
|
||||
|
||||
# Warn if GroupChat is underpopulated, without established changing behavior
|
||||
n_agents = len(self.agent_names)
|
||||
if n_agents < 3:
|
||||
logger.warning(
|
||||
f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
|
||||
)
|
||||
|
||||
if self.func_call_filter and self.messages and "function_call" in self.messages[-1]:
|
||||
# find agents with the right function_map which contains the function name
|
||||
agents = [
|
||||
agent for agent in self.agents if agent.can_execute_function(self.messages[-1]["function_call"]["name"])
|
||||
]
|
||||
if len(agents) == 1:
|
||||
# only one agent can execute the function
|
||||
return agents[0]
|
||||
elif not agents:
|
||||
# find all the agents with function_map
|
||||
agents = [agent for agent in self.agents if agent.function_map]
|
||||
if len(agents) == 1:
|
||||
return agents[0]
|
||||
elif not agents:
|
||||
raise ValueError(
|
||||
f"No agent can execute the function {self.messages[-1]['name']}. "
|
||||
"Please check the function_map of the agents."
|
||||
)
|
||||
else:
|
||||
agents = self.agents
|
||||
# Warn if GroupChat is underpopulated
|
||||
n_agents = len(agents)
|
||||
if n_agents < 3:
|
||||
logger.warning(
|
||||
f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
|
||||
)
|
||||
selector.update_system_message(self.select_speaker_msg(agents))
|
||||
final, name = selector.generate_oai_reply(
|
||||
self.messages
|
||||
+ [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"Read the above conversation. Then select the next role from {self.agent_names} to play. Only return the role.",
|
||||
"content": f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.",
|
||||
}
|
||||
]
|
||||
)
|
||||
if not final:
|
||||
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id
|
||||
return self.next_agent(last_speaker)
|
||||
return self.next_agent(last_speaker, agents)
|
||||
try:
|
||||
return self.agent_by_name(name)
|
||||
except ValueError:
|
||||
return self.next_agent(last_speaker)
|
||||
return self.next_agent(last_speaker, agents)
|
||||
|
||||
def _participant_roles(self):
|
||||
return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents])
|
||||
|
||||
@@ -1,6 +1,54 @@
|
||||
import pytest
|
||||
import autogen
|
||||
|
||||
|
||||
def test_func_call_groupchat():
|
||||
agent1 = autogen.ConversableAgent(
|
||||
"alice",
|
||||
human_input_mode="NEVER",
|
||||
llm_config=False,
|
||||
default_auto_reply="This is alice sepaking.",
|
||||
)
|
||||
agent2 = autogen.ConversableAgent(
|
||||
"bob",
|
||||
human_input_mode="NEVER",
|
||||
llm_config=False,
|
||||
default_auto_reply="This is bob speaking.",
|
||||
function_map={"test_func": lambda x: x},
|
||||
)
|
||||
groupchat = autogen.GroupChat(agents=[agent1, agent2], messages=[], max_round=3)
|
||||
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
|
||||
agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "test_func", "arguments": '{"x": 1}'}})
|
||||
|
||||
assert len(groupchat.messages) == 3
|
||||
assert (
|
||||
groupchat.messages[-2]["role"] == "function"
|
||||
and groupchat.messages[-2]["name"] == "test_func"
|
||||
and groupchat.messages[-2]["content"] == "1"
|
||||
)
|
||||
assert groupchat.messages[-1]["name"] == "alice"
|
||||
|
||||
agent3 = autogen.ConversableAgent(
|
||||
"carol",
|
||||
human_input_mode="NEVER",
|
||||
llm_config=False,
|
||||
default_auto_reply="This is carol speaking.",
|
||||
function_map={"test_func": lambda x: x + 1},
|
||||
)
|
||||
groupchat = autogen.GroupChat(agents=[agent1, agent2, agent3], messages=[], max_round=3)
|
||||
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
|
||||
agent3.initiate_chat(group_chat_manager, message={"function_call": {"name": "test_func", "arguments": '{"x": 1}'}})
|
||||
|
||||
assert (
|
||||
groupchat.messages[-2]["role"] == "function"
|
||||
and groupchat.messages[-2]["name"] == "test_func"
|
||||
and groupchat.messages[-2]["content"] == "1"
|
||||
)
|
||||
assert groupchat.messages[-1]["name"] == "carol"
|
||||
|
||||
agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}})
|
||||
|
||||
|
||||
def test_chat_manager():
|
||||
agent1 = autogen.ConversableAgent(
|
||||
"alice",
|
||||
@@ -30,6 +78,9 @@ def test_chat_manager():
|
||||
agent2.initiate_chat(group_chat_manager, message="hello")
|
||||
assert len(groupchat.messages) == 2
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}})
|
||||
|
||||
|
||||
def test_plugin():
|
||||
# Give another Agent class ability to manage group chat
|
||||
@@ -62,6 +113,7 @@ def test_plugin():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_func_call_groupchat()
|
||||
# test_broadcast()
|
||||
# test_chat_manager()
|
||||
test_plugin()
|
||||
test_chat_manager()
|
||||
# test_plugin()
|
||||
|
||||
Reference in New Issue
Block a user