mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Fix-swarm-handoff (#4198)
* fix select speaker for swarm * Fix max-turn = 1 for swarm
This commit is contained in:
@@ -40,14 +40,22 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
||||
self._current_speaker = self._participant_topic_types[0]
|
||||
|
||||
async def select_speaker(self, thread: List[AgentMessage]) -> str:
|
||||
"""Select a speaker from the participants based on handoff message."""
|
||||
if len(thread) > 0 and isinstance(thread[-1], HandoffMessage):
|
||||
self._current_speaker = thread[-1].target
|
||||
if self._current_speaker not in self._participant_topic_types:
|
||||
raise ValueError("The selected speaker in the handoff message is not a participant.")
|
||||
return self._current_speaker
|
||||
else:
|
||||
"""Select a speaker from the participants based on handoff message.
|
||||
Looks for the last handoff message in the thread to determine the next speaker."""
|
||||
if len(thread) == 0:
|
||||
return self._current_speaker
|
||||
for message in reversed(thread):
|
||||
if isinstance(message, HandoffMessage):
|
||||
self._current_speaker = message.target
|
||||
if self._current_speaker not in self._participant_topic_types:
|
||||
raise ValueError(
|
||||
f"The target {self._current_speaker} in the handoff message "
|
||||
f"is not one of the participants {self._participant_topic_types}. "
|
||||
"If you are resuming the Swarm with a new task make sure to include in your task "
|
||||
"a handoff message with a valid participant as the target."
|
||||
)
|
||||
return self._current_speaker
|
||||
return self._current_speaker
|
||||
|
||||
|
||||
class Swarm(BaseGroupChat):
|
||||
|
||||
@@ -791,3 +791,27 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
index += 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_swarm_pause_and_resume() -> None:
|
||||
first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
|
||||
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
|
||||
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
|
||||
|
||||
team = Swarm([second_agent, first_agent, third_agent], max_turns=1)
|
||||
result = await team.run(task="task")
|
||||
assert len(result.messages) == 2
|
||||
assert result.messages[0].content == "task"
|
||||
assert result.messages[1].content == "Transferred to third_agent."
|
||||
|
||||
# Resume with a new task.
|
||||
result = await team.run(task="new task")
|
||||
assert len(result.messages) == 2
|
||||
assert result.messages[0].content == "new task"
|
||||
assert result.messages[1].content == "Transferred to first_agent."
|
||||
|
||||
# Resume with the same task.
|
||||
result = await team.run()
|
||||
assert len(result.messages) == 1
|
||||
assert result.messages[0].content == "Transferred to second_agent."
|
||||
|
||||
@@ -281,7 +281,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
"version": "3.12.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -764,6 +764,16 @@
|
||||
"# Use `asyncio.run(Console(lazy_agent_team.run_stream(task=\"It is raining in New York.\")))` when running in a script.\n",
|
||||
"await Console(lazy_agent_team.run_stream(task=\"It is raining in New York.\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```{note}\n",
|
||||
"Currently the handoff termination approach does not work with {py:class}`~autogen_agentchat.teams.Swarm`.\n",
|
||||
"Please stay tuned for the updates.\n",
|
||||
"```"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
Reference in New Issue
Block a user