mirror of
https://github.com/microsoft/autogen.git
synced 2026-01-27 22:07:55 -05:00
Partial fix for 960 (#963)
* Partial fix for 960 * Fixed a missing = None * Added test coverage.
This commit is contained in:
@@ -64,8 +64,10 @@ class GroupChat:
|
||||
"""Returns the agent with a given name."""
|
||||
return self.agents[self.agent_names.index(name)]
|
||||
|
||||
def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
|
||||
def next_agent(self, agent: Agent, agents: Optional[List[Agent]] = None) -> Agent:
|
||||
"""Return the next agent in the list."""
|
||||
if agents is None:
|
||||
agents = self.agents
|
||||
|
||||
# What index is the agent? (-1 if not present)
|
||||
idx = self.agent_names.index(agent.name) if agent.name in self.agent_names else -1
|
||||
@@ -79,20 +81,26 @@ class GroupChat:
|
||||
if self.agents[(offset + i) % len(self.agents)] in agents:
|
||||
return self.agents[(offset + i) % len(self.agents)]
|
||||
|
||||
def select_speaker_msg(self, agents: List[Agent]) -> str:
|
||||
def select_speaker_msg(self, agents: Optional[List[Agent]] = None) -> str:
|
||||
"""Return the system message for selecting the next speaker. This is always the *first* message in the context."""
|
||||
if agents is None:
|
||||
agents = self.agents
|
||||
return f"""You are in a role play game. The following roles are available:
|
||||
{self._participant_roles(agents)}.
|
||||
|
||||
Read the following conversation.
|
||||
Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""
|
||||
|
||||
def select_speaker_prompt(self, agents: List[Agent]) -> str:
|
||||
def select_speaker_prompt(self, agents: Optional[List[Agent]] = None) -> str:
|
||||
"""Return the floating system prompt selecting the next speaker. This is always the *last* message in the context."""
|
||||
if agents is None:
|
||||
agents = self.agents
|
||||
return f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."
|
||||
|
||||
def manual_select_speaker(self, agents: List[Agent]) -> Union[Agent, None]:
|
||||
def manual_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[Agent, None]:
|
||||
"""Manually select the next speaker."""
|
||||
if agents is None:
|
||||
agents = self.agents
|
||||
|
||||
print("Please select the next speaker from the following list:")
|
||||
_n_agents = len(agents)
|
||||
|
||||
@@ -421,6 +421,10 @@ def test_next_agent():
|
||||
assert groupchat.next_agent(agent2, [agent1, agent2, agent3]) == agent3
|
||||
assert groupchat.next_agent(agent3, [agent1, agent2, agent3]) == agent1
|
||||
|
||||
assert groupchat.next_agent(agent1) == agent2
|
||||
assert groupchat.next_agent(agent2) == agent3
|
||||
assert groupchat.next_agent(agent3) == agent1
|
||||
|
||||
assert groupchat.next_agent(agent1, [agent1, agent3]) == agent3
|
||||
assert groupchat.next_agent(agent3, [agent1, agent3]) == agent1
|
||||
|
||||
@@ -429,6 +433,48 @@ def test_next_agent():
|
||||
assert groupchat.next_agent(agent4, [agent1, agent2, agent3]) == agent1
|
||||
|
||||
|
||||
def test_selection_helpers():
|
||||
agent1 = autogen.ConversableAgent(
|
||||
"alice",
|
||||
max_consecutive_auto_reply=10,
|
||||
human_input_mode="NEVER",
|
||||
llm_config=False,
|
||||
default_auto_reply="This is alice speaking.",
|
||||
description="Alice is an AI agent.",
|
||||
)
|
||||
agent2 = autogen.ConversableAgent(
|
||||
"bob",
|
||||
max_consecutive_auto_reply=10,
|
||||
human_input_mode="NEVER",
|
||||
llm_config=False,
|
||||
description="Bob is an AI agent.",
|
||||
)
|
||||
agent3 = autogen.ConversableAgent(
|
||||
"sam",
|
||||
max_consecutive_auto_reply=10,
|
||||
human_input_mode="NEVER",
|
||||
llm_config=False,
|
||||
default_auto_reply="This is sam speaking.",
|
||||
system_message="Sam is an AI agent.",
|
||||
)
|
||||
|
||||
# Test empty is_termination_msg function
|
||||
groupchat = autogen.GroupChat(
|
||||
agents=[agent1, agent2, agent3], messages=[], speaker_selection_method="round_robin", max_round=10
|
||||
)
|
||||
|
||||
select_speaker_msg = groupchat.select_speaker_msg()
|
||||
select_speaker_prompt = groupchat.select_speaker_prompt()
|
||||
|
||||
assert "Alice is an AI agent." in select_speaker_msg
|
||||
assert "Bob is an AI agent." in select_speaker_msg
|
||||
assert "Sam is an AI agent." in select_speaker_msg
|
||||
assert str(["Alice", "Bob", "Sam"]).lower() in select_speaker_prompt.lower()
|
||||
|
||||
with mock.patch.object(builtins, "input", lambda _: "1"):
|
||||
groupchat.manual_select_speaker()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test_func_call_groupchat()
|
||||
# test_broadcast()
|
||||
|
||||
Reference in New Issue
Block a user