diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index 5fd4c1854..622a4d491 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -4,7 +4,7 @@ from typing import AsyncGenerator, List, Sequence from autogen_core.base import CancellationToken from ..base import ChatAgent, Response, TaskResult -from ..messages import AgentMessage, ChatMessage, MultiModalMessage, TextMessage +from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage class BaseChatAgent(ChatAgent, ABC): @@ -54,7 +54,7 @@ class BaseChatAgent(ChatAgent, ABC): async def run( self, *, - task: str | TextMessage | MultiModalMessage | None = None, + task: str | ChatMessage | None = None, cancellation_token: CancellationToken | None = None, ) -> TaskResult: """Run the agent with the given task and return the result.""" @@ -62,13 +62,17 @@ class BaseChatAgent(ChatAgent, ABC): cancellation_token = CancellationToken() input_messages: List[ChatMessage] = [] output_messages: List[AgentMessage] = [] - if isinstance(task, str): + if task is None: + pass + elif isinstance(task, str): text_msg = TextMessage(content=task, source="user") input_messages.append(text_msg) output_messages.append(text_msg) - elif isinstance(task, TextMessage | MultiModalMessage): + elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage): input_messages.append(task) output_messages.append(task) + else: + raise ValueError(f"Invalid task type: {type(task)}") response = await self.on_messages(input_messages, cancellation_token) if response.inner_messages is not None: output_messages += response.inner_messages @@ -78,7 +82,7 @@ class BaseChatAgent(ChatAgent, ABC): async def run_stream( self, *, - task: str | TextMessage | MultiModalMessage | None = None, + task: str | ChatMessage | None = None, cancellation_token: CancellationToken | None = None, ) -> AsyncGenerator[AgentMessage | TaskResult, None]: """Run the agent with the given task and return a stream of messages @@ -87,15 +91,19 @@ class BaseChatAgent(ChatAgent, ABC): cancellation_token = CancellationToken() input_messages: List[ChatMessage] = [] output_messages: List[AgentMessage] = [] - if isinstance(task, str): + if task is None: + pass + elif isinstance(task, str): text_msg = TextMessage(content=task, source="user") input_messages.append(text_msg) output_messages.append(text_msg) yield text_msg - elif isinstance(task, TextMessage | MultiModalMessage): + elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage): input_messages.append(task) output_messages.append(task) yield task + else: + raise ValueError(f"Invalid task type: {type(task)}") async for message in self.on_messages_stream(input_messages, cancellation_token): if isinstance(message, Response): yield message.chat_message diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py index 0a5e37dce..d2cb39eb6 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py @@ -3,7 +3,7 @@ from typing import AsyncGenerator, Protocol, Sequence from autogen_core.base import CancellationToken -from ..messages import AgentMessage, MultiModalMessage, TextMessage +from ..messages import AgentMessage, ChatMessage @dataclass @@ -23,7 +23,7 @@ class TaskRunner(Protocol): async def run( self, *, - task: str | TextMessage | MultiModalMessage | None = None, + task: str | ChatMessage | None = None, cancellation_token: CancellationToken | None = None, ) -> TaskResult: """Run the task and return the result. @@ -36,7 +36,7 @@ class TaskRunner(Protocol): def run_stream( self, *, - task: str | TextMessage | MultiModalMessage | None = None, + task: str | ChatMessage | None = None, cancellation_token: CancellationToken | None = None, ) -> AsyncGenerator[AgentMessage | TaskResult, None]: """Run the task and produces a stream of messages and the final result diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index c41ebaf46..fbca26449 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -12,13 +12,12 @@ from autogen_core.base import ( AgentType, CancellationToken, MessageContext, - TopicId, ) from autogen_core.components import ClosureAgent, TypeSubscription from ... import EVENT_LOGGER_NAME from ...base import ChatAgent, TaskResult, Team, TerminationCondition -from ...messages import AgentMessage, MultiModalMessage, TextMessage +from ...messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage from ._chat_agent_container import ChatAgentContainer from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination from ._sequential_routed_agent import SequentialRoutedAgent @@ -164,7 +163,7 @@ class BaseGroupChat(Team, ABC): async def run( self, *, - task: str | TextMessage | MultiModalMessage | None = None, + task: str | ChatMessage | None = None, cancellation_token: CancellationToken | None = None, ) -> TaskResult: """Run the team and return the result. The base implementation uses @@ -215,7 +214,7 @@ class BaseGroupChat(Team, ABC): async def run_stream( self, *, - task: str | TextMessage | MultiModalMessage | None = None, + task: str | ChatMessage | None = None, cancellation_token: CancellationToken | None = None, ) -> AsyncGenerator[AgentMessage | TaskResult, None]: """Run the team and produces a stream of messages and the final result @@ -253,6 +252,16 @@ class BaseGroupChat(Team, ABC): asyncio.run(main()) """ + # Create the first chat message if the task is a string or a chat message. + first_chat_message: ChatMessage | None = None + if task is None: + pass + elif isinstance(task, str): + first_chat_message = TextMessage(content=task, source="user") + elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage): + first_chat_message = task + else: + raise ValueError(f"Invalid task type: {type(task)}") if self._is_running: raise ValueError("The team is already running, it cannot run again until it is stopped.") @@ -265,17 +274,6 @@ class BaseGroupChat(Team, ABC): if not self._initialized: await self._init(self._runtime) - # Run the team by publishing the start message. - first_chat_message: TextMessage | MultiModalMessage | None = None - if isinstance(task, str): - first_chat_message = TextMessage(content=task, source="user") - elif isinstance(task, TextMessage | MultiModalMessage): - first_chat_message = task - await self._runtime.publish_message( - GroupChatStart(message=first_chat_message), - topic_id=TopicId(type=self._group_topic_type, source=self._team_id), - ) - # Start a coroutine to stop the runtime and signal the output message queue is complete. async def stop_runtime() -> None: await self._runtime.stop_when_idle() @@ -283,24 +281,37 @@ class BaseGroupChat(Team, ABC): shutdown_task = asyncio.create_task(stop_runtime()) - # Collect the output messages in order. - output_messages: List[AgentMessage] = [] - # Yield the messsages until the queue is empty. - while True: - message = await self._output_message_queue.get() - if message is None: - break - yield message - output_messages.append(message) + try: + # Run the team by sending the start message to the group chat manager. + # The group chat manager will start the group chat by relaying the message to the participants + # and the closure agent. + await self._runtime.send_message( + GroupChatStart(message=first_chat_message), + recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id), + ) + # Collect the output messages in order. + output_messages: List[AgentMessage] = [] + # Yield the messsages until the queue is empty. + while True: + message = await self._output_message_queue.get() + if message is None: + break + yield message + output_messages.append(message) - # Wait for the shutdown task to finish. - await shutdown_task + # Yield the final result. + yield TaskResult(messages=output_messages, stop_reason=self._stop_reason) - # Yield the final result. - yield TaskResult(messages=output_messages, stop_reason=self._stop_reason) + finally: + # Wait for the shutdown task to finish. + await shutdown_task - # Indicate that the team is no longer running. - self._is_running = False + # Clear the output message queue. + while not self._output_message_queue.empty(): + self._output_message_queue.get_nowait() + + # Indicate that the team is no longer running. + self._is_running = False async def reset(self) -> None: """Reset the team and its participants to their initial state. @@ -352,19 +363,26 @@ class BaseGroupChat(Team, ABC): # Start the runtime. self._runtime.start() - # Send a reset message to the group chat. - await self._runtime.publish_message( - GroupChatReset(), - topic_id=TopicId(type=self._group_topic_type, source=self._team_id), - ) + try: + # Send a reset messages to all participants. + for participant_topic_type in self._participant_topic_types: + await self._runtime.send_message( + GroupChatReset(), + recipient=AgentId(type=participant_topic_type, key=self._team_id), + ) + # Send a reset message to the group chat manager. + await self._runtime.send_message( + GroupChatReset(), + recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id), + ) + finally: + # Stop the runtime. + await self._runtime.stop_when_idle() - # Stop the runtime. - await self._runtime.stop_when_idle() + # Reset the output message queue. + self._stop_reason = None + while not self._output_message_queue.empty(): + self._output_message_queue.get_nowait() - # Reset the output message queue. - self._stop_reason = None - while not self._output_message_queue.empty(): - self._output_message_queue.get_nowait() - - # Indicate that the team is no longer running. - self._is_running = False + # Indicate that the team is no longer running. + self._is_running = False diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index e28f46a8c..d2a2b9176 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -2,10 +2,10 @@ from abc import ABC, abstractmethod from typing import Any, List from autogen_core.base import MessageContext -from autogen_core.components import DefaultTopicId, event +from autogen_core.components import DefaultTopicId, event, rpc from ...base import TerminationCondition -from ...messages import AgentMessage, StopMessage +from ...messages import AgentMessage, ChatMessage, StopMessage from ._events import ( GroupChatAgentResponse, GroupChatRequestPublish, @@ -55,7 +55,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): self._max_turns = max_turns self._current_turn = 0 - @event + @rpc async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None: """Handle the start of a group chat by selecting a speaker to start the conversation.""" @@ -70,10 +70,16 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): # Stop the group chat. return + # Validate the group state given the start message. + await self.validate_group_state(message.message) + if message.message is not None: # Log the start message. await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type)) + # Relay the start message to the participants. + await self.publish_message(message, topic_id=DefaultTopicId(type=self._group_topic_type)) + # Append the user message to the message thread. self._message_thread.append(message.message) @@ -137,11 +143,16 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): speaker_topic_type = await self.select_speaker(self._message_thread) await self.publish_message(GroupChatRequestPublish(), topic_id=DefaultTopicId(type=speaker_topic_type)) - @event + @rpc async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None: # Reset the group chat manager. await self.reset() + @abstractmethod + async def validate_group_state(self, message: ChatMessage | None) -> None: + """Validate the state of the group chat given the start message. This is executed when the group chat manager receives a GroupChatStart event.""" + ... + @abstractmethod async def select_speaker(self, thread: List[AgentMessage]) -> str: """Select a speaker from the participants and return the diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py index d249676fd..315708032 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -1,7 +1,7 @@ from typing import Any, List from autogen_core.base import MessageContext -from autogen_core.components import DefaultTopicId, event +from autogen_core.components import DefaultTopicId, event, rpc from ...base import ChatAgent, Response from ...messages import ChatMessage @@ -38,7 +38,7 @@ class ChatAgentContainer(SequentialRoutedAgent): """Handle an agent response event by appending the content to the buffer.""" self._message_buffer.append(message.agent_response.chat_message) - @event + @rpc async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None: """Handle a reset event by resetting the agent.""" self._message_buffer.clear() diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py index 5410673eb..ae1567a7d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -2,7 +2,7 @@ import json from typing import Any, List from autogen_core.base import MessageContext -from autogen_core.components import DefaultTopicId, Image, event +from autogen_core.components import DefaultTopicId, Image, event, rpc from autogen_core.components.models import ( AssistantMessage, ChatCompletionClient, @@ -102,7 +102,7 @@ class MagenticOneOrchestrator(SequentialRoutedAgent): def _get_final_answer_prompt(self, task: str) -> str: return ORCHESTRATOR_FINAL_ANSWER_PROMPT.format(task=task) - @event + @rpc async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None: """Handle the start of a group chat by selecting a speaker to start the conversation.""" assert message is not None and message.message is not None @@ -145,7 +145,7 @@ class MagenticOneOrchestrator(SequentialRoutedAgent): self._message_thread.append(message.agent_response.chat_message) await self._orchestrate_step() - @event + @rpc async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None: # Reset the group chat manager. await self.reset() diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py index 8e3a262a8..f5c128a6c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py @@ -1,7 +1,7 @@ from typing import Callable, List from ...base import ChatAgent, TerminationCondition -from ...messages import AgentMessage +from ...messages import AgentMessage, ChatMessage from ._base_group_chat import BaseGroupChat from ._base_group_chat_manager import BaseGroupChatManager @@ -28,6 +28,9 @@ class RoundRobinGroupChatManager(BaseGroupChatManager): ) self._next_speaker_index = 0 + async def validate_group_state(self, message: ChatMessage | None) -> None: + pass + async def reset(self) -> None: self._current_turn = 0 self._message_thread.clear() @@ -68,7 +71,7 @@ class RoundRobinGroupChat(BaseGroupChat): from autogen_ext.models import OpenAIChatCompletionClient from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.teams import RoundRobinGroupChat - from autogen_agentchat.task import TextMentionTermination + from autogen_agentchat.task import TextMentionTermination, Console async def main() -> None: @@ -84,9 +87,7 @@ class RoundRobinGroupChat(BaseGroupChat): ) termination = TextMentionTermination("TERMINATE") team = RoundRobinGroupChat([assistant], termination_condition=termination) - stream = team.run_stream("What's the weather in New York?") - async for message in stream: - print(message) + await Console(team.run_stream(task="What's the weather in New York?")) asyncio.run(main()) @@ -99,7 +100,7 @@ class RoundRobinGroupChat(BaseGroupChat): from autogen_ext.models import OpenAIChatCompletionClient from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.teams import RoundRobinGroupChat - from autogen_agentchat.task import TextMentionTermination + from autogen_agentchat.task import TextMentionTermination, Console async def main() -> None: @@ -109,9 +110,7 @@ class RoundRobinGroupChat(BaseGroupChat): agent2 = AssistantAgent("Assistant2", model_client=model_client) termination = TextMentionTermination("TERMINATE") team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination) - stream = team.run_stream("Tell me some jokes.") - async for message in stream: - print(message) + await Console(team.run_stream(task="Tell me some jokes.")) asyncio.run(main()) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index f618c0d38..e21e99e0f 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -8,6 +8,7 @@ from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME from ...base import ChatAgent, TerminationCondition from ...messages import ( AgentMessage, + ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, @@ -53,6 +54,9 @@ class SelectorGroupChatManager(BaseGroupChatManager): self._allow_repeated_speaker = allow_repeated_speaker self._selector_func = selector_func + async def validate_group_state(self, message: ChatMessage | None) -> None: + pass + async def reset(self) -> None: self._current_turn = 0 self._message_thread.clear() @@ -204,7 +208,7 @@ class SelectorGroupChat(BaseGroupChat): from autogen_ext.models import OpenAIChatCompletionClient from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.teams import SelectorGroupChat - from autogen_agentchat.task import TextMentionTermination + from autogen_agentchat.task import TextMentionTermination, Console async def main() -> None: @@ -243,9 +247,7 @@ class SelectorGroupChat(BaseGroupChat): model_client=model_client, termination_condition=termination, ) - stream = team.run_stream("Book a 3-day trip to new york.") - async for message in stream: - print(message) + await Console(team.run_stream(task="Book a 3-day trip to new york.")) asyncio.run(main()) @@ -258,7 +260,7 @@ class SelectorGroupChat(BaseGroupChat): from autogen_ext.models import OpenAIChatCompletionClient from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.teams import SelectorGroupChat - from autogen_agentchat.task import TextMentionTermination + from autogen_agentchat.task import TextMentionTermination, Console async def main() -> None: @@ -299,9 +301,7 @@ class SelectorGroupChat(BaseGroupChat): termination_condition=termination, ) - stream = team.run_stream("What is 1 + 1?") - async for message in stream: - print(message) + await Console(team.run_stream(task="What is 1 + 1?")) asyncio.run(main()) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py index 651367169..0e658ab75 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py @@ -3,7 +3,7 @@ from typing import Callable, List from ... import EVENT_LOGGER_NAME from ...base import ChatAgent, TerminationCondition -from ...messages import AgentMessage, HandoffMessage +from ...messages import AgentMessage, ChatMessage, HandoffMessage from ._base_group_chat import BaseGroupChat from ._base_group_chat_manager import BaseGroupChatManager @@ -32,6 +32,31 @@ class SwarmGroupChatManager(BaseGroupChatManager): ) self._current_speaker = participant_topic_types[0] + async def validate_group_state(self, message: ChatMessage | None) -> None: + """Validate the start message for the group chat.""" + # Check if the start message is a handoff message. + if isinstance(message, HandoffMessage): + if message.target not in self._participant_topic_types: + raise ValueError( + f"The target {message.target} is not one of the participants {self._participant_topic_types}. " + "If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target." + ) + return + # Check if there is a handoff message in the thread that is not targeting a valid participant. + for existing_message in reversed(self._message_thread): + if isinstance(existing_message, HandoffMessage): + if existing_message.target not in self._participant_topic_types: + raise ValueError( + f"The existing handoff target {existing_message.target} is not one of the participants {self._participant_topic_types}. " + "If you are resuming Swarm with a new task make sure to include in your task " + "a HandoffMessage with a valid participant as the target. For example, if you are " + "resuming from a HandoffTermination, make sure the new task is a HandoffMessage " + "with a valid participant as the target." + ) + # The latest handoff message should always target a valid participant. + # Do not look past the latest handoff message. + return + async def reset(self) -> None: self._current_turn = 0 self._message_thread.clear() @@ -47,13 +72,8 @@ class SwarmGroupChatManager(BaseGroupChatManager): for message in reversed(thread): if isinstance(message, HandoffMessage): self._current_speaker = message.target - if self._current_speaker not in self._participant_topic_types: - raise ValueError( - f"The target {self._current_speaker} in the handoff message " - f"is not one of the participants {self._participant_topic_types}. " - "If you are resuming the Swarm with a new task make sure to include in your task " - "a handoff message with a valid participant as the target." - ) + # The latest handoff message should always target a valid participant. + assert self._current_speaker in self._participant_topic_types return self._current_speaker return self._current_speaker @@ -72,7 +92,7 @@ class Swarm(BaseGroupChat): Without a termination condition, the group chat will run indefinitely. max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit. - Examples: + Basic example: .. code-block:: python @@ -99,11 +119,49 @@ class Swarm(BaseGroupChat): termination = MaxMessageTermination(3) team = Swarm([agent1, agent2], termination_condition=termination) - stream = team.run_stream("What is bob's birthday?") + stream = team.run_stream(task="What is bob's birthday?") async for message in stream: print(message) + asyncio.run(main()) + + + Using the :class:`~autogen_agentchat.task.HandoffTermination` for human-in-the-loop handoff: + + .. code-block:: python + + import asyncio + from autogen_ext.models import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.teams import Swarm + from autogen_agentchat.task import HandoffTermination, Console, MaxMessageTermination + from autogen_agentchat.messages import HandoffMessage + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + + agent = AssistantAgent( + "Alice", + model_client=model_client, + handoffs=["user"], + system_message="You are Alice and you only answer questions about yourself, ask the user for help if needed.", + ) + termination = HandoffTermination(target="user") | MaxMessageTermination(3) + team = Swarm([agent], termination_condition=termination) + + # Start the conversation. + await Console(team.run_stream(task="What is bob's birthday?")) + + # Resume with user feedback. + await Console( + team.run_stream( + task=HandoffMessage(source="user", target="Alice", content="Bob's birthday is on 1st January.") + ) + ) + + asyncio.run(main()) """ diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index db8bfa9d4..b6c4f4bfd 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -24,7 +24,7 @@ from autogen_agentchat.messages import ( ToolCallMessage, ToolCallResultMessage, ) -from autogen_agentchat.task import MaxMessageTermination, TextMentionTermination +from autogen_agentchat.task import HandoffTermination, MaxMessageTermination, TextMentionTermination from autogen_agentchat.teams import ( RoundRobinGroupChat, SelectorGroupChat, @@ -815,3 +815,56 @@ async def test_swarm_pause_and_resume() -> None: result = await team.run() assert len(result.messages) == 1 assert result.messages[0].content == "Transferred to second_agent." + + +@pytest.mark.asyncio +async def test_swarm_with_handoff_termination() -> None: + first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent") + second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent") + third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent") + + # Handoff to an existing agent. + termination = HandoffTermination(target="third_agent") + team = Swarm([second_agent, first_agent, third_agent], termination_condition=termination) + # Start + result = await team.run(task="task") + assert len(result.messages) == 2 + assert result.messages[0].content == "task" + assert result.messages[1].content == "Transferred to third_agent." + # Resume existing. + result = await team.run() + assert len(result.messages) == 3 + assert result.messages[0].content == "Transferred to first_agent." + assert result.messages[1].content == "Transferred to second_agent." + assert result.messages[2].content == "Transferred to third_agent." + # Resume new task. + result = await team.run(task="new task") + assert len(result.messages) == 4 + assert result.messages[0].content == "new task" + assert result.messages[1].content == "Transferred to first_agent." + assert result.messages[2].content == "Transferred to second_agent." + assert result.messages[3].content == "Transferred to third_agent." + + # Handoff to a non-existing agent. + third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="non_existing_agent") + termination = HandoffTermination(target="non_existing_agent") + team = Swarm([second_agent, first_agent, third_agent], termination_condition=termination) + # Start + result = await team.run(task="task") + assert len(result.messages) == 3 + assert result.messages[0].content == "task" + assert result.messages[1].content == "Transferred to third_agent." + assert result.messages[2].content == "Transferred to non_existing_agent." + # Attempt to resume. + with pytest.raises(ValueError): + await team.run() + # Attempt to resume with a new task. + with pytest.raises(ValueError): + await team.run(task="new task") + # Resume with a HandoffMessage + result = await team.run(task=HandoffMessage(content="Handoff to first_agent.", target="first_agent", source="user")) + assert len(result.messages) == 4 + assert result.messages[0].content == "Handoff to first_agent." + assert result.messages[1].content == "Transferred to second_agent." + assert result.messages[2].content == "Transferred to third_agent." + assert result.messages[3].content == "Transferred to non_existing_agent." diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/teams.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/teams.ipynb index 0b5af2088..49340ef11 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/teams.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/teams.ipynb @@ -764,16 +764,6 @@ "# Use `asyncio.run(Console(lazy_agent_team.run_stream(task=\"It is raining in New York.\")))` when running in a script.\n", "await Console(lazy_agent_team.run_stream(task=\"It is raining in New York.\"))" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```{note}\n", - "Currently the handoff termination approach does not work with {py:class}`~autogen_agentchat.teams.Swarm`.\n", - "Please stay tuned for the updates.\n", - "```" - ] } ], "metadata": { @@ -792,7 +782,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.6" + "version": "3.11.5" } }, "nbformat": 4,