diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index 1ebe658c18..6c008bef66 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -2,7 +2,7 @@ import asyncio from abc import ABC, abstractmethod from typing import Any, List, Sequence -from autogen_core import DefaultTopicId, MessageContext, event, rpc +from autogen_core import CancellationToken, DefaultTopicId, MessageContext, event, rpc from ...base import TerminationCondition from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, SelectSpeakerEvent, StopMessage @@ -79,6 +79,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): self._current_turn = 0 self._message_factory = message_factory self._emit_team_events = emit_team_events + self._active_speakers: List[str] = [] @rpc async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None: @@ -122,58 +123,35 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): # Stop the group chat. return - # Select a speaker to start/continue the conversation - speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread)) - # Link the select speaker future to the cancellation token. - ctx.cancellation_token.link_future(speaker_name_future) - speaker_name = await speaker_name_future - if speaker_name not in self._participant_name_to_topic_type: - raise RuntimeError(f"Speaker {speaker_name} not found in participant names.") - await self._log_speaker_selection(speaker_name) - - # Send the message to the next speaker - speaker_topic_type = self._participant_name_to_topic_type[speaker_name] - await self.publish_message( - GroupChatRequestPublish(), - topic_id=DefaultTopicId(type=speaker_topic_type), - cancellation_token=ctx.cancellation_token, - ) - - async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None: - self._message_thread.extend(messages) + # Select speakers to start/continue the conversation + await self._transition_to_next_speakers(ctx.cancellation_token) @event async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: try: - # Append the message to the message thread and construct the delta. + # Construct the detla from the agent response. delta: List[BaseAgentEvent | BaseChatMessage] = [] if message.agent_response.inner_messages is not None: for inner_message in message.agent_response.inner_messages: delta.append(inner_message) delta.append(message.agent_response.chat_message) + + # Append the messages to the message thread. await self.update_message_thread(delta) + # Remove the agent from the active speakers list. + self._active_speakers.remove(message.agent_name) + if len(self._active_speakers) > 0: + # If there are still active speakers, return without doing anything. + return + # Check if the conversation should be terminated. if await self._apply_termination_condition(delta, increment_turn_count=True): # Stop the group chat. return - # Select a speaker to continue the conversation. - speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread)) - # Link the select speaker future to the cancellation token. - ctx.cancellation_token.link_future(speaker_name_future) - speaker_name = await speaker_name_future - if speaker_name not in self._participant_name_to_topic_type: - raise RuntimeError(f"Speaker {speaker_name} not found in participant names.") - await self._log_speaker_selection(speaker_name) - - # Send the message to the next speakers - speaker_topic_type = self._participant_name_to_topic_type[speaker_name] - await self.publish_message( - GroupChatRequestPublish(), - topic_id=DefaultTopicId(type=speaker_topic_type), - cancellation_token=ctx.cancellation_token, - ) + # Select speakers to continue the conversation. + await self._transition_to_next_speakers(ctx.cancellation_token) except Exception as e: # Handle the exception and signal termination with an error. error = SerializableException.from_exception(e) @@ -181,6 +159,29 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): # Raise the exception to the runtime. raise + async def _transition_to_next_speakers(self, cancellation_token: CancellationToken) -> None: + speaker_names_future = asyncio.ensure_future(self.select_speaker(self._message_thread)) + # Link the select speaker future to the cancellation token. + cancellation_token.link_future(speaker_names_future) + speaker_names = await speaker_names_future + if isinstance(speaker_names, str): + # If only one speaker is selected, convert it to a list. + speaker_names = [speaker_names] + for speaker_name in speaker_names: + if speaker_name not in self._participant_name_to_topic_type: + raise RuntimeError(f"Speaker {speaker_name} not found in participant names.") + await self._log_speaker_selection(speaker_names) + + # Send request to publish message to the next speakers + for speaker_name in speaker_names: + speaker_topic_type = self._participant_name_to_topic_type[speaker_name] + await self.publish_message( + GroupChatRequestPublish(), + topic_id=DefaultTopicId(type=speaker_topic_type), + cancellation_token=cancellation_token, + ) + self._active_speakers.append(speaker_name) + async def _apply_termination_condition( self, delta: Sequence[BaseAgentEvent | BaseChatMessage], increment_turn_count: bool = False ) -> bool: @@ -216,9 +217,9 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): return True return False - async def _log_speaker_selection(self, speaker_name: str) -> None: + async def _log_speaker_selection(self, speaker_names: List[str]) -> None: """Log the selected speaker to the output message queue.""" - select_msg = SelectSpeakerEvent(content=[speaker_name], source=self._name) + select_msg = SelectSpeakerEvent(content=speaker_names, source=self._name) if self._emit_team_events: await self.publish_message( GroupChatMessage(message=select_msg), @@ -284,10 +285,26 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): """ ... + async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None: + """Update the message thread with the new messages. + This is called when the group chat receives a GroupChatStart or GroupChatAgentResponse event, + before calling the select_speakers method. + """ + self._message_thread.extend(messages) + @abstractmethod - async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str: - """Select a speaker from the participants and return the - topic type of the selected speaker.""" + async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str: + """Select speakers from the participants and return the topic types of the selected speaker. + This is called when the group chat manager have received all responses from the participants + for a turn and is ready to select the next speakers for the next turn. + + Args: + thread: The message thread of the group chat. + + Returns: + A list of topic types of the selected speakers. + If only one speaker is selected, a single string is returned instead of a list. + """ ... @abstractmethod diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py index 69faeb4917..f9ec8636f6 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -89,7 +89,7 @@ class ChatAgentContainer(SequentialRoutedAgent): # Publish the response to the group chat. self._message_buffer.clear() await self.publish_message( - GroupChatAgentResponse(agent_response=response), + GroupChatAgentResponse(agent_response=response, agent_name=self._agent.name), topic_id=DefaultTopicId(type=self._parent_topic_type), cancellation_token=ctx.cancellation_token, ) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py index ca07d87bbe..febb69e603 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py @@ -48,6 +48,9 @@ class GroupChatAgentResponse(BaseModel): agent_response: Response """The response from an agent.""" + agent_name: str + """The name of the agent that produced the response.""" + class GroupChatRequestPublish(BaseModel): """A request to publish a message to a group chat.""" diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py index 87b083b3de..94d133b201 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py @@ -1,5 +1,6 @@ import asyncio -from typing import Any, Callable, Dict, List, Literal, Mapping, Sequence, Set +from collections import Counter, deque +from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel from pydantic import BaseModel @@ -208,154 +209,70 @@ class GraphFlowManager(BaseGroupChatManager): max_turns=max_turns, message_factory=message_factory, ) - self._graph = graph - self._graph.graph_validate() - self._graph_has_cycles = self._graph.get_has_cycles() - if self._graph_has_cycles and self._termination_condition is None and self._max_turns is None: + graph.graph_validate() + if graph.get_has_cycles() and self._termination_condition is None and self._max_turns is None: raise ValueError("A termination condition is required for cyclic graphs without a maximum turn limit.") + self._graph = graph + # Lookup table for incoming edges for each node. + self._parents = graph.get_parents() + # Lookup table for outgoing edges for each node. + self._edges: Dict[str, List[DiGraphEdge]] = {n: node.edges for n, node in graph.nodes.items()} + # Activation lookup table for each node. + self._activation: Dict[str, Literal["any", "all"]] = {n: node.activation for n, node in graph.nodes.items()} - self._use_default_start = self._graph.default_start_node is not None - self._default_start_executed = False - self._start_nodes = graph.get_start_nodes() - self._leaf_nodes = graph.get_leaf_nodes() - self._parents = graph.get_parents() # Parent node dependencies - helper dict to get all incoming edges - self._active_nodes: Set[str] = set() # Currently executing nodes - self._active_node_count: Dict[str, int] = { - node: 0 for node in graph.nodes - } # Number of times a node has been active + # === Mutable states for the graph execution === + # Count the number of remaining parents to activate each node. + self._remaining: Counter[str] = Counter({n: len(p) for n, p in self._parents.items()}) + # Lookup table for nodes that have been enqueued through an any activation. + # This is used to prevent re-adding the same node multiple times. + self._enqueued_any: Dict[str, bool] = {n: False for n in graph.nodes} + # Ready queue for nodes that are ready to execute, starting with the start nodes. + self._ready: Deque[str] = deque([n for n in graph.get_start_nodes()]) - # These are nodes next in line for execution as one or more of their parent nodes have started execution. - # They execute when all their parent nodes have executed. - # Nodes are added to this dict when at least one of their parent nodes becomes active. - # Start nodes (no parents) are added to this dict at initialization as they are always ready to run. - self._pending_execution: Dict[str, List[str]] = {node: [] for node in graph.get_start_nodes()} + async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None: + await super().update_message_thread(messages) - def _get_valid_target(self, node: DiGraphNode, content: str) -> str: - """Check if a condition is met in the chat history.""" - for edge in node.edges: - if edge.condition and edge.condition in content: - return edge.target + # Find the node that ran in the current turn. + message = messages[-1] + if message.source not in self._graph.nodes: + # Ignore messages from sources outside of the graph. + return + assert isinstance(message, BaseChatMessage) + source = message.source + content = message.to_model_text() - raise RuntimeError(f"Condition not met for node {node.name}. Content: {content}") - - def _is_node_ready(self, node_name: str) -> bool: - """Check if a node is ready to execute based on its parent nodes. - If activation is any then execute as soon as any parent has finished - If activation is all then execute only when all parents have finished - """ - node = self._graph.nodes[node_name] - if node.activation == "any": - return bool(self._pending_execution[node_name]) - return all(parent in self._pending_execution[node_name] for parent in self._parents[node_name]) - - async def _select_speakers(self, thread: List[BaseAgentEvent | BaseChatMessage], many: bool = True) -> List[str]: - """Select the next set of agents to execute based on DAG constraints.""" - next_speakers: Set[str] = set() - source_node: DiGraphNode | None = None - source: str | None = None - - if thread and isinstance(thread[-1], BaseChatMessage): - source = thread[-1].source # name of the agent that just finished - content = thread[-1].to_model_text() - - # Safety check: only an active node can send a response - if source != "user": - if source not in self._active_nodes: - raise RuntimeError(f"Agent '{source}' is not currently active.") - - # Mark the node as no longer active (it just finished) - self._active_node_count[source] -= 1 - - if self._active_node_count[source] <= 0: - self._active_nodes.remove(source) - - source_node = self._graph.nodes[source] - - if source_node.edges: - # Case: conditional edges — only execute if condition is met - target_nodes_names: List[str] = [] - if source_node.edges[0].condition is not None: - target_nodes_names = [self._get_valid_target(source_node, content)] - other_nodes = [ - edge.target for edge in source_node.edges if edge.target != target_nodes_names[0] - ] - for other_node in other_nodes: - other_active_parents = [ - parent - for parent in self._parents[other_node] - if (parent != source and parent in self._active_nodes) - ] - if not other_active_parents: - self._pending_execution.pop(other_node) - else: - self._pending_execution[other_node] = other_active_parents - - else: - # Case: unconditional edges — mark this source as completed for all its children - target_nodes_names = [edge.target for edge in source_node.edges] - - for target in target_nodes_names: - self._pending_execution[target].append(source) + # Propagate the update to the children of the node. + for edge in self._edges[source]: + if edge.condition and edge.condition not in content: + continue + if self._activation[edge.target] == "all": + self._remaining[edge.target] -= 1 + if self._remaining[edge.target] == 0: + # If all parents are done, add to the ready queue. + self._ready.append(edge.target) else: - # TODO: Check if there are any usecase where the User can decide on the next speaker - pass + # If activation is any, add to the ready queue if not already enqueued. + if not self._enqueued_any[edge.target]: + self._ready.append(edge.target) + self._enqueued_any[edge.target] = True - # After updating _pending_execution, check which nodes are now unblocked - for node_name in list(self._pending_execution): - if self._use_default_start and not self._default_start_executed: - if node_name == self._graph.default_start_node: - next_speakers.add(node_name) - self._default_start_executed = True - break + async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]: + # Drain the ready queue for the next set of speakers. + speakers: List[str] = [] + while self._ready: + speaker = self._ready.popleft() + speakers.append(speaker) + # Reset the bookkeeping for the node that were selected. + if self._activation[speaker] == "any": + self._enqueued_any[speaker] = False + else: + self._remaining[speaker] = len(self._parents[speaker]) - if self._is_node_ready(node_name): - next_speakers.add(node_name) - node = self._graph.nodes[node_name] - if node.activation == "all": - self._pending_execution.pop(node_name) - else: - # If activation is any, remove the parent that just finished - if source is not None: - self._pending_execution[node_name] = [ - parent for parent in self._pending_execution[node_name] if parent != source - ] - - # If none of the other parents of this node are active, remove this node from pending execution - node_parents = self._parents[node_name] - if not any(parent in self._active_nodes for parent in node_parents): - self._pending_execution.pop(node_name) - - if not many: - break - - # Prepopulate children of next_speakers into _pending_execution - for node_name in next_speakers: - for edge in self._graph.nodes[node_name].edges: - if edge.target not in self._pending_execution: - self._pending_execution[edge.target] = [] - - # Mark newly selected speakers as active - for speaker in next_speakers: - if speaker not in self._active_nodes: - self._active_nodes.add(speaker) - - self._active_node_count[speaker] += 1 - - if not self._pending_execution and not next_speakers and not self._active_nodes: - next_speakers = set([_DIGRAPH_STOP_AGENT_NAME]) # Call the termination agent - - return list(next_speakers) - - async def select_speakers(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> List[str]: - return await self._select_speakers(thread) - - async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str: - """Select a speaker from the participants and return the - topic type of the selected speaker.""" - speakers = await self._select_speakers(thread, many=False) + # If there are no speakers, trigger the stop agent. if not speakers: - raise RuntimeError("No available speakers found.") - return speakers[0] + speakers = [_DIGRAPH_STOP_AGENT_NAME] + + return speakers async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: pass @@ -365,10 +282,9 @@ class GraphFlowManager(BaseGroupChatManager): state = { "message_thread": [message.dump() for message in self._message_thread], "current_turn": self._current_turn, - "active_nodes": list(self._active_nodes), - "pending_execution": self._pending_execution, - "active_node_count": self._active_node_count, - "default_start_executed": self._default_start_executed, + "remaining": dict(self._remaining), + "enqueued_any": dict(self._enqueued_any), + "ready": list(self._ready), } return state @@ -376,10 +292,9 @@ class GraphFlowManager(BaseGroupChatManager): """Restore execution state from saved data.""" self._message_thread = [self._message_factory.create(msg) for msg in state["message_thread"]] self._current_turn = state["current_turn"] - self._active_nodes = set(state["active_nodes"]) - self._pending_execution = state["pending_execution"] - self._active_node_count = state["active_node_count"] - self._default_start_executed = state.get("default_start_executed", False) + self._remaining = Counter(state["remaining"]) + self._enqueued_any = state["enqueued_any"] + self._ready = deque(state["ready"]) async def reset(self) -> None: """Reset execution state to the start of the graph.""" @@ -387,11 +302,9 @@ class GraphFlowManager(BaseGroupChatManager): self._message_thread.clear() if self._termination_condition: await self._termination_condition.reset() - - self._active_nodes = set() - self._active_node_count = {node: 0 for node in self._graph.nodes} - self._pending_execution = {node: [] for node in self._start_nodes} - self._default_start_executed = False + self._remaining = Counter({n: len(p) for n, p in self._parents.items()}) + self._enqueued_any = {n: False for n in self._graph.nodes} + self._ready = deque([n for n in self._graph.get_start_nodes()]) class _StopAgent(BaseChatAgent): diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py index 34d1df7cf9..e921770013 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -2,7 +2,7 @@ import asyncio import json import logging import re -from typing import Any, Dict, List, Mapping +from typing import Any, Dict, List, Mapping, Sequence from autogen_core import AgentId, CancellationToken, DefaultTopicId, MessageContext, event, rpc from autogen_core.models import ( @@ -37,6 +37,7 @@ from .._events import ( GroupChatReset, GroupChatStart, GroupChatTermination, + SerializableException, ) from ._prompts import ( ORCHESTRATOR_FINAL_ANSWER_PROMPT, @@ -187,22 +188,29 @@ class MagenticOneOrchestrator(BaseGroupChatManager): @event async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: # type: ignore - delta: List[BaseAgentEvent | BaseChatMessage] = [] - if message.agent_response.inner_messages is not None: - for inner_message in message.agent_response.inner_messages: - delta.append(inner_message) - await self.update_message_thread([message.agent_response.chat_message]) - delta.append(message.agent_response.chat_message) + try: + delta: List[BaseAgentEvent | BaseChatMessage] = [] + if message.agent_response.inner_messages is not None: + for inner_message in message.agent_response.inner_messages: + delta.append(inner_message) + await self.update_message_thread([message.agent_response.chat_message]) + delta.append(message.agent_response.chat_message) - if self._termination_condition is not None: - stop_message = await self._termination_condition(delta) - if stop_message is not None: - # Reset the termination conditions. - await self._termination_condition.reset() - # Signal termination. - await self._signal_termination(stop_message) - return - await self._orchestrate_step(ctx.cancellation_token) + if self._termination_condition is not None: + stop_message = await self._termination_condition(delta) + if stop_message is not None: + # Reset the termination conditions. + await self._termination_condition.reset() + # Signal termination. + await self._signal_termination(stop_message) + return + + await self._orchestrate_step(ctx.cancellation_token) + except Exception as e: + error = SerializableException.from_exception(e) + await self._signal_termination_with_error(error) + # Raise the error to the runtime. + raise async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: pass @@ -229,9 +237,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager): self._n_rounds = orchestrator_state.n_rounds self._n_stalls = orchestrator_state.n_stalls - async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str: + async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str: """Not used in this orchestrator, we select next speaker in _orchestrate_step.""" - return "" + return [""] async def reset(self) -> None: """Reset the group chat manager.""" @@ -275,7 +283,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager): # Broadcast await self.publish_message( - GroupChatAgentResponse(agent_response=Response(chat_message=ledger_message)), + GroupChatAgentResponse(agent_response=Response(chat_message=ledger_message), agent_name=self._name), topic_id=DefaultTopicId(type=self._group_topic_type), ) @@ -389,7 +397,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager): # Broadcast it await self.publish_message( # Broadcast - GroupChatAgentResponse(agent_response=Response(chat_message=message)), + GroupChatAgentResponse(agent_response=Response(chat_message=message), agent_name=self._name), topic_id=DefaultTopicId(type=self._group_topic_type), cancellation_token=cancellation_token, ) @@ -470,7 +478,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager): # Broadcast await self.publish_message( - GroupChatAgentResponse(agent_response=Response(chat_message=message)), + GroupChatAgentResponse(agent_response=Response(chat_message=message), agent_name=self._name), topic_id=DefaultTopicId(type=self._group_topic_type), cancellation_token=cancellation_token, ) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py index 6ea915511d..d6b43afb28 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Callable, List, Mapping +from typing import Any, Callable, List, Mapping, Sequence from autogen_core import AgentRuntime, Component, ComponentModel from pydantic import BaseModel @@ -69,8 +69,13 @@ class RoundRobinGroupChatManager(BaseGroupChatManager): self._current_turn = round_robin_state.current_turn self._next_speaker_index = round_robin_state.next_speaker_index - async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str: - """Select a speaker from the participants in a round-robin fashion.""" + async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str: + """Select a speaker from the participants in a round-robin fashion. + + .. note:: + + This method always returns a single speaker. + """ current_speaker_index = self._next_speaker_index self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_names) current_speaker = self._participant_names[current_speaker_index] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index d54b7aea29..a0d9e7732d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -150,10 +150,14 @@ class SelectorGroupChatManager(BaseGroupChatManager): base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)] await self._add_messages_to_context(self._model_context, base_chat_messages) - async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str: + async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str: """Selects the next speaker in a group chat using a ChatCompletion client, with the selector function as override if it returns a speaker name. + .. note:: + + This method always returns a single speaker name. + A key assumption is that the agent type is the same as the topic type, which we use as the agent name. """ # Use the selector function if provided. @@ -171,7 +175,7 @@ class SelectorGroupChatManager(BaseGroupChatManager): f"Expected one of: {self._participant_names}." ) # Skip the model based selection. - return speaker + return [speaker] # Use the candidate function to filter participants if provided if self._candidate_func is not None: @@ -211,7 +215,7 @@ class SelectorGroupChatManager(BaseGroupChatManager): agent_name = participants[0] self._previous_speaker = agent_name trace_logger.debug(f"Selected speaker: {agent_name}") - return agent_name + return [agent_name] def construct_message_history(self, message_history: List[LLMMessage]) -> str: # Construct the history of the conversation. diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py index 7842e4c468..3449940154 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Callable, List, Mapping +from typing import Any, Callable, List, Mapping, Sequence from autogen_core import AgentRuntime, Component, ComponentModel from pydantic import BaseModel @@ -79,17 +79,22 @@ class SwarmGroupChatManager(BaseGroupChatManager): await self._termination_condition.reset() self._current_speaker = self._participant_names[0] - async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str: + async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str: """Select a speaker from the participants based on handoff message. - Looks for the last handoff message in the thread to determine the next speaker.""" + Looks for the last handoff message in the thread to determine the next speaker. + + .. note:: + + This method always returns a single speaker. + """ if len(thread) == 0: - return self._current_speaker + return [self._current_speaker] for message in reversed(thread): if isinstance(message, HandoffMessage): self._current_speaker = message.target # The latest handoff message should always target a valid participant. assert self._current_speaker in self._participant_names - return self._current_speaker + return [self._current_speaker] return self._current_speaker async def save_state(self) -> Mapping[str, Any]: diff --git a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py index 7c8baf4b10..2f381528c0 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py @@ -1,6 +1,6 @@ import asyncio -from typing import Any, AsyncGenerator, Callable, Dict, List, Sequence, Set -from unittest.mock import AsyncMock, patch +from typing import AsyncGenerator, List, Sequence +from unittest.mock import patch import pytest import pytest_asyncio @@ -12,9 +12,8 @@ from autogen_agentchat.agents import ( PerSourceFilter, ) from autogen_agentchat.base import Response, TaskResult -from autogen_agentchat.conditions import MaxMessageTermination +from autogen_agentchat.conditions import MaxMessageTermination, SourceMatchTermination from autogen_agentchat.messages import BaseChatMessage, ChatMessage, MessageFactory, StopMessage, TextMessage -from autogen_agentchat.messages import BaseTextChatMessage as TextChatMessage from autogen_agentchat.teams import ( DiGraphBuilder, GraphFlow, @@ -269,61 +268,6 @@ def test_validate_graph_mixed_conditions() -> None: graph.graph_validate() -def test_get_valid_target() -> None: - node = DiGraphNode( - name="A", - edges=[DiGraphEdge(target="B", condition="approve"), DiGraphEdge(target="C", condition="reject")], - ) - manager = GraphFlowManager.__new__(GraphFlowManager) - - assert manager._get_valid_target(node, "please approve this") == "B" # pyright: ignore[reportPrivateUsage] - assert manager._get_valid_target(node, "i reject this") == "C" # pyright: ignore[reportPrivateUsage] - with pytest.raises(RuntimeError): - manager._get_valid_target(node, "unknown path") # pyright: ignore[reportPrivateUsage] - - -def test_is_node_ready_all_and_any() -> None: - graph = DiGraph( - nodes={ - "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="C")]), - "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]), - "C": DiGraphNode(name="C", edges=[], activation="all"), - } - ) - - manager = GraphFlowManager.__new__(GraphFlowManager) - manager._graph = graph # pyright: ignore[reportPrivateUsage] - manager._parents = graph.get_parents() # pyright: ignore[reportPrivateUsage] - - # === Test "all" activation === - # Case 1: No parent finished - manager._pending_execution = {"C": []} # pyright: ignore[reportPrivateUsage] - assert not manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage] - - # Case 2: One parent finished - manager._pending_execution = {"C": ["A"]} # pyright: ignore[reportPrivateUsage] - assert not manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage] - - # Case 3: All parents finished - manager._pending_execution = {"C": ["A", "B"]} # pyright: ignore[reportPrivateUsage] - assert manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage] - - # === Test "any" activation === - graph.nodes["C"].activation = "any" - - # Case 1: No parent finished - manager._pending_execution = {"C": []} # pyright: ignore[reportPrivateUsage] - assert not manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage] - - # Case 2: One parent finished - manager._pending_execution = {"C": ["B"]} # pyright: ignore[reportPrivateUsage] - assert manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage] - - # Case 3: All parents finished - manager._pending_execution = {"C": ["A", "B"]} # pyright: ignore[reportPrivateUsage] - assert manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage] - - @pytest.mark.asyncio async def test_invalid_digraph_manager_cycle_without_termination() -> None: """Test GraphManager raises error for cyclic graph without termination condition.""" @@ -359,170 +303,6 @@ async def test_invalid_digraph_manager_cycle_without_termination() -> None: ) -@pytest.fixture -def digraph_manager() -> Callable[..., GraphFlowManager]: - @patch( - "autogen_agentchat.teams._group_chat._base_group_chat_manager.BaseGroupChatManager.__init__", return_value=None - ) - def _create( - _: Any, - graph: DiGraph, - active_nodes: Set[str] | None = None, - thread: List[BaseAgentEvent | BaseChatMessage] | None = None, - pending: Dict[str, List[str]] | None = None, - ) -> GraphFlowManager: - manager = GraphFlowManager.__new__(GraphFlowManager) - manager._graph = graph # pyright: ignore[reportPrivateUsage] - manager._parents = graph.get_parents() # pyright: ignore[reportPrivateUsage] - manager._start_nodes = graph.get_start_nodes() # pyright: ignore[reportPrivateUsage] - manager._leaf_nodes = graph.get_leaf_nodes() # pyright: ignore[reportPrivateUsage] - manager._active_nodes = set(active_nodes or []) # pyright: ignore[reportPrivateUsage] - manager._active_node_count = {node: 0 for node in graph.nodes} # pyright: ignore[reportPrivateUsage] - manager._message_factory = MessageFactory() # pyright: ignore[reportPrivateUsage] - manager._message_thread = thread if thread is not None else [] # pyright: ignore[reportPrivateUsage] - manager._pending_execution = pending if pending is not None else {node: [] for node in graph.get_start_nodes()} # pyright: ignore[reportPrivateUsage] - manager._name = "test_manager" # pyright: ignore[reportPrivateUsage] - manager._use_default_start = False # pyright: ignore[reportPrivateUsage] - return manager - - return _create - - -# -------------------- Test: Sequential Flow -------------------- -@pytest.mark.asyncio -async def test_select_speakers_linear(digraph_manager: Callable[..., GraphFlowManager]) -> None: - graph = DiGraph( - nodes={ - "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]), - "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]), - "C": DiGraphNode(name="C", edges=[]), - } - ) - message_thread = [TextChatMessage(source="A", content="done", metadata={})] - manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []}) - - result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage] - assert result == ["B"] - assert "B" in manager._active_nodes # pyright: ignore[reportPrivateUsage] - - -# -------------------- Test: Parallel Fan-out -------------------- - - -@pytest.mark.asyncio -async def test_select_speakers_parallel(digraph_manager: Callable[..., GraphFlowManager]) -> None: - graph = DiGraph( - nodes={ - "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]), - "B": DiGraphNode(name="B", edges=[]), - "C": DiGraphNode(name="C", edges=[]), - } - ) - message_thread = [TextChatMessage(source="A", content="done", metadata={})] - manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []}) - - result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage] - assert set(result) == {"B", "C"} - assert "B" in manager._active_nodes # pyright: ignore[reportPrivateUsage] - assert "C" in manager._active_nodes # pyright: ignore[reportPrivateUsage] - - -# -------------------- Test: Conditional Path -------------------- -@pytest.mark.asyncio -async def test_select_speakers_conditional(digraph_manager: Callable[..., GraphFlowManager]) -> None: - graph = DiGraph( - nodes={ - "A": DiGraphNode( - name="A", edges=[DiGraphEdge(target="B", condition="yes"), DiGraphEdge(target="C", condition="no")] - ), - "B": DiGraphNode(name="B", edges=[]), - "C": DiGraphNode(name="C", edges=[]), - } - ) - message_thread = [TextChatMessage(source="A", content="no", metadata={})] - manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []}) - - result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage] - assert result == ["C"] - assert "C" in manager._active_nodes # pyright: ignore[reportPrivateUsage] - - -@pytest.mark.asyncio -async def test_select_speakers_from_start_nodes(digraph_manager: Callable[..., GraphFlowManager]) -> None: - graph = DiGraph( - nodes={ - "A": DiGraphNode(name="A", edges=[]), - "B": DiGraphNode(name="B", edges=[]), - } - ) - # No prior message — both are start nodes - manager = digraph_manager(graph=graph, active_nodes=set(), thread=[], pending={"A": [], "B": []}) - result = await manager.select_speakers([]) - assert set(result) == {"A", "B"} - - -@pytest.mark.asyncio -async def test_select_speakers_termination(digraph_manager: Callable[..., GraphFlowManager]) -> None: - graph = DiGraph( - nodes={ - "A": DiGraphNode(name="A", edges=[]), - } - ) - - # Create the manager and manually patch _signal_termination to track calls - manager = digraph_manager( - graph=graph, active_nodes={"A"}, thread=[TextChatMessage(source="A", content="done", metadata={})], pending={} - ) - manager._signal_termination = AsyncMock() # type: ignore[assignment] - - result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage] - - # No speakers left to run, so result should be empty - assert result == [_DIGRAPH_STOP_AGENT_NAME] - - -@pytest.mark.asyncio -async def test_select_speakers_conditional_all_activation( - digraph_manager: Callable[..., GraphFlowManager], -) -> None: - graph = DiGraph( - nodes={ - "A": DiGraphNode( - name="A", edges=[DiGraphEdge(target="B", condition="yes"), DiGraphEdge(target="C", condition="no")] - ), - "B": DiGraphNode(name="B", edges=[], activation="all"), - "C": DiGraphNode(name="C", edges=[], activation="all"), - } - ) - message_thread = [TextChatMessage(source="A", content="no", metadata={})] - manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []}) - - result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage] - assert result == ["C"] - assert "C" in manager._active_nodes # pyright: ignore[reportPrivateUsage] - - -@pytest.mark.asyncio -async def test_select_speakers_conditional_any_activation( - digraph_manager: Callable[..., GraphFlowManager], -) -> None: - graph = DiGraph( - nodes={ - "A": DiGraphNode( - name="A", edges=[DiGraphEdge(target="B", condition="yes"), DiGraphEdge(target="C", condition="no")] - ), - "B": DiGraphNode(name="B", edges=[], activation="any"), - "C": DiGraphNode(name="C", edges=[], activation="any"), - } - ) - message_thread = [TextChatMessage(source="A", content="yes", metadata={})] - manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []}) - - result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage] - assert result == ["B"] - assert "B" in manager._active_nodes # pyright: ignore[reportPrivateUsage] - - class _EchoAgent(BaseChatAgent): def __init__(self, name: str, description: str) -> None: super().__init__(name, description) @@ -1417,7 +1197,6 @@ async def test_graph_flow_serialize_deserialize() -> None: participants=builder.get_participants(), graph=builder.build(), runtime=None, - termination_condition=MaxMessageTermination(5), ) serialized = team.dump_component() @@ -1439,11 +1218,52 @@ async def test_graph_flow_serialize_deserialize() -> None: assert results.messages[1].source == "A" assert results.messages[1].content == "0" assert isinstance(results.messages[2], TextMessage) - assert results.messages[2].source == "A" - assert results.messages[2].content == "1" - assert isinstance(results.messages[3], TextMessage) - assert results.messages[3].source == "B" - assert results.messages[3].content == "0" + assert results.messages[2].source == "B" + assert results.messages[2].content == "0" assert isinstance(results.messages[-1], StopMessage) assert results.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME assert results.messages[-1].content == "Digraph execution is complete" + + +@pytest.mark.asyncio +async def test_graph_flow_stateful_pause_and_resume_with_termination() -> None: + client_a = ReplayChatCompletionClient(["A1", "A2"]) + client_b = ReplayChatCompletionClient(["B1"]) + + a = AssistantAgent("A", model_client=client_a) + b = AssistantAgent("B", model_client=client_b) + + builder = DiGraphBuilder() + builder.add_node(a).add_node(b) + builder.add_edge(a, b) + builder.set_entry_point(a) + + team = GraphFlow( + participants=builder.get_participants(), + graph=builder.build(), + runtime=None, + termination_condition=SourceMatchTermination(sources=["A"]), + ) + + result = await team.run(task="Start") + assert len(result.messages) == 2 + assert result.messages[0].source == "user" + assert result.messages[1].source == "A" + assert result.stop_reason is not None and result.stop_reason == "'A' answered" + + # Export state. + state = await team.save_state() + + # Load state into a new team. + new_team = GraphFlow( + participants=builder.get_participants(), + graph=builder.build(), + runtime=None, + ) + await new_team.load_state(state) + + # Resume. + result = await new_team.run() + assert len(result.messages) == 2 + assert result.messages[0].source == "B" + assert result.messages[1].source == _DIGRAPH_STOP_AGENT_NAME