Fix-swarm-handoff (#4198)

* fix select speaker for swarm

* Fix max-turn = 1 for swarm
This commit is contained in:
Eric Zhu
2024-11-15 10:02:59 -08:00
committed by GitHub
parent c6d69ab4c1
commit 78019dd2dc
4 changed files with 50 additions and 8 deletions

View File

@@ -40,14 +40,22 @@ class SwarmGroupChatManager(BaseGroupChatManager):
self._current_speaker = self._participant_topic_types[0]
async def select_speaker(self, thread: List[AgentMessage]) -> str:
"""Select a speaker from the participants based on handoff message."""
if len(thread) > 0 and isinstance(thread[-1], HandoffMessage):
self._current_speaker = thread[-1].target
if self._current_speaker not in self._participant_topic_types:
raise ValueError("The selected speaker in the handoff message is not a participant.")
return self._current_speaker
else:
"""Select a speaker from the participants based on handoff message.
Looks for the last handoff message in the thread to determine the next speaker."""
if len(thread) == 0:
return self._current_speaker
for message in reversed(thread):
if isinstance(message, HandoffMessage):
self._current_speaker = message.target
if self._current_speaker not in self._participant_topic_types:
raise ValueError(
f"The target {self._current_speaker} in the handoff message "
f"is not one of the participants {self._participant_topic_types}. "
"If you are resuming the Swarm with a new task make sure to include in your task "
"a handoff message with a valid participant as the target."
)
return self._current_speaker
return self._current_speaker
class Swarm(BaseGroupChat):

View File

@@ -791,3 +791,27 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
else:
assert message == result.messages[index]
index += 1
@pytest.mark.asyncio
async def test_swarm_pause_and_resume() -> None:
first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
team = Swarm([second_agent, first_agent, third_agent], max_turns=1)
result = await team.run(task="task")
assert len(result.messages) == 2
assert result.messages[0].content == "task"
assert result.messages[1].content == "Transferred to third_agent."
# Resume with a new task.
result = await team.run(task="new task")
assert len(result.messages) == 2
assert result.messages[0].content == "new task"
assert result.messages[1].content == "Transferred to first_agent."
# Resume with the same task.
result = await team.run()
assert len(result.messages) == 1
assert result.messages[0].content == "Transferred to second_agent."

View File

@@ -281,7 +281,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.6"
}
},
"nbformat": 4,

View File

@@ -764,6 +764,16 @@
"# Use `asyncio.run(Console(lazy_agent_team.run_stream(task=\"It is raining in New York.\")))` when running in a script.\n",
"await Console(lazy_agent_team.run_stream(task=\"It is raining in New York.\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{note}\n",
"Currently the handoff termination approach does not work with {py:class}`~autogen_agentchat.teams.Swarm`.\n",
"Please stay tuned for the updates.\n",
"```"
]
}
],
"metadata": {