From 78019dd2dc177a2f74685e7cb088e5b408a98934 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Fri, 15 Nov 2024 10:02:59 -0800 Subject: [PATCH] Fix-swarm-handoff (#4198) * fix select speaker for swarm * Fix max-turn = 1 for swarm --- .../teams/_group_chat/_swarm_group_chat.py | 22 +++++++++++------ .../tests/test_group_chat.py | 24 +++++++++++++++++++ .../tutorial/selector-group-chat.ipynb | 2 +- .../agentchat-user-guide/tutorial/teams.ipynb | 10 ++++++++ 4 files changed, 50 insertions(+), 8 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py index d37684feb..651367169 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py @@ -40,14 +40,22 @@ class SwarmGroupChatManager(BaseGroupChatManager): self._current_speaker = self._participant_topic_types[0] async def select_speaker(self, thread: List[AgentMessage]) -> str: - """Select a speaker from the participants based on handoff message.""" - if len(thread) > 0 and isinstance(thread[-1], HandoffMessage): - self._current_speaker = thread[-1].target - if self._current_speaker not in self._participant_topic_types: - raise ValueError("The selected speaker in the handoff message is not a participant.") - return self._current_speaker - else: + """Select a speaker from the participants based on handoff message. + Looks for the last handoff message in the thread to determine the next speaker.""" + if len(thread) == 0: return self._current_speaker + for message in reversed(thread): + if isinstance(message, HandoffMessage): + self._current_speaker = message.target + if self._current_speaker not in self._participant_topic_types: + raise ValueError( + f"The target {self._current_speaker} in the handoff message " + f"is not one of the participants {self._participant_topic_types}. " + "If you are resuming the Swarm with a new task make sure to include in your task " + "a handoff message with a valid participant as the target." + ) + return self._current_speaker + return self._current_speaker class Swarm(BaseGroupChat): diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 11eb2c925..db8bfa9d4 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -791,3 +791,27 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) - else: assert message == result.messages[index] index += 1 + + +@pytest.mark.asyncio +async def test_swarm_pause_and_resume() -> None: + first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent") + second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent") + third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent") + + team = Swarm([second_agent, first_agent, third_agent], max_turns=1) + result = await team.run(task="task") + assert len(result.messages) == 2 + assert result.messages[0].content == "task" + assert result.messages[1].content == "Transferred to third_agent." + + # Resume with a new task. + result = await team.run(task="new task") + assert len(result.messages) == 2 + assert result.messages[0].content == "new task" + assert result.messages[1].content == "Transferred to first_agent." + + # Resume with the same task. + result = await team.run() + assert len(result.messages) == 1 + assert result.messages[0].content == "Transferred to second_agent." diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/selector-group-chat.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/selector-group-chat.ipynb index 5124ca2e6..0a3bb0fe3 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/selector-group-chat.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/selector-group-chat.ipynb @@ -281,7 +281,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.12.6" } }, "nbformat": 4, diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/teams.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/teams.ipynb index b7899768a..8d001f897 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/teams.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/teams.ipynb @@ -764,6 +764,16 @@ "# Use `asyncio.run(Console(lazy_agent_team.run_stream(task=\"It is raining in New York.\")))` when running in a script.\n", "await Console(lazy_agent_team.run_stream(task=\"It is raining in New York.\"))" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{note}\n", + "Currently the handoff termination approach does not work with {py:class}`~autogen_agentchat.teams.Swarm`.\n", + "Please stay tuned for the updates.\n", + "```" + ] } ], "metadata": {