Enable concurrent execution of agents in GraphFlow (#6545)

Support concurrent execution in `GraphFlow`:
- Updated `BaseGroupChatManager.select_speaker` to return a union of a
single string or a list of speaker name strings and added logics to
check for currently activated speakers and only proceed to select next
speakers when all activated speakers have finished.
- Updated existing teams (e.g., `SelectorGroupChat`) with the new
signature, while still returning a single speaker in their
implementations.
- Updated `GraphFlow` to support multiple speakers selected. 
- Refactored `GraphFlow` for less dictionary gymnastic by using a queue
and update using `update_message_thread`.

Example: a fan out graph:

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
from autogen_ext.models.openai import OpenAIChatCompletionClient

async def main():
    # Initialize agents with OpenAI model clients.
    model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
    agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.")
    agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.")
    agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Japanese.")

    # Create a directed graph with fan-out flow A -> (B, C).
    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
    builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
    graph = builder.build()

    # Create a GraphFlow team with the directed graph.
    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
    )

    # Run the team and print the events.
    async for event in team.run_stream(task="Write a short story about a cat."):
        print(event)


asyncio.run(main())
```

Resolves:
#6541 
#6533
This commit is contained in:
Eric Zhu
2025-05-19 14:47:55 -07:00
committed by GitHub
parent 2eadef440e
commit f0b73441b6
9 changed files with 232 additions and 457 deletions

View File

@@ -2,7 +2,7 @@ import asyncio
from abc import ABC, abstractmethod
from typing import Any, List, Sequence
from autogen_core import DefaultTopicId, MessageContext, event, rpc
from autogen_core import CancellationToken, DefaultTopicId, MessageContext, event, rpc
from ...base import TerminationCondition
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, SelectSpeakerEvent, StopMessage
@@ -79,6 +79,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
self._current_turn = 0
self._message_factory = message_factory
self._emit_team_events = emit_team_events
self._active_speakers: List[str] = []
@rpc
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
@@ -122,58 +123,35 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
# Stop the group chat.
return
# Select a speaker to start/continue the conversation
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_name_future)
speaker_name = await speaker_name_future
if speaker_name not in self._participant_name_to_topic_type:
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
await self._log_speaker_selection(speaker_name)
# Send the message to the next speaker
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=speaker_topic_type),
cancellation_token=ctx.cancellation_token,
)
async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
self._message_thread.extend(messages)
# Select speakers to start/continue the conversation
await self._transition_to_next_speakers(ctx.cancellation_token)
@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
try:
# Append the message to the message thread and construct the delta.
# Construct the detla from the agent response.
delta: List[BaseAgentEvent | BaseChatMessage] = []
if message.agent_response.inner_messages is not None:
for inner_message in message.agent_response.inner_messages:
delta.append(inner_message)
delta.append(message.agent_response.chat_message)
# Append the messages to the message thread.
await self.update_message_thread(delta)
# Remove the agent from the active speakers list.
self._active_speakers.remove(message.agent_name)
if len(self._active_speakers) > 0:
# If there are still active speakers, return without doing anything.
return
# Check if the conversation should be terminated.
if await self._apply_termination_condition(delta, increment_turn_count=True):
# Stop the group chat.
return
# Select a speaker to continue the conversation.
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_name_future)
speaker_name = await speaker_name_future
if speaker_name not in self._participant_name_to_topic_type:
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
await self._log_speaker_selection(speaker_name)
# Send the message to the next speakers
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=speaker_topic_type),
cancellation_token=ctx.cancellation_token,
)
# Select speakers to continue the conversation.
await self._transition_to_next_speakers(ctx.cancellation_token)
except Exception as e:
# Handle the exception and signal termination with an error.
error = SerializableException.from_exception(e)
@@ -181,6 +159,29 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
# Raise the exception to the runtime.
raise
async def _transition_to_next_speakers(self, cancellation_token: CancellationToken) -> None:
speaker_names_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
cancellation_token.link_future(speaker_names_future)
speaker_names = await speaker_names_future
if isinstance(speaker_names, str):
# If only one speaker is selected, convert it to a list.
speaker_names = [speaker_names]
for speaker_name in speaker_names:
if speaker_name not in self._participant_name_to_topic_type:
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
await self._log_speaker_selection(speaker_names)
# Send request to publish message to the next speakers
for speaker_name in speaker_names:
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=speaker_topic_type),
cancellation_token=cancellation_token,
)
self._active_speakers.append(speaker_name)
async def _apply_termination_condition(
self, delta: Sequence[BaseAgentEvent | BaseChatMessage], increment_turn_count: bool = False
) -> bool:
@@ -216,9 +217,9 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
return True
return False
async def _log_speaker_selection(self, speaker_name: str) -> None:
async def _log_speaker_selection(self, speaker_names: List[str]) -> None:
"""Log the selected speaker to the output message queue."""
select_msg = SelectSpeakerEvent(content=[speaker_name], source=self._name)
select_msg = SelectSpeakerEvent(content=speaker_names, source=self._name)
if self._emit_team_events:
await self.publish_message(
GroupChatMessage(message=select_msg),
@@ -284,10 +285,26 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
"""
...
async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
"""Update the message thread with the new messages.
This is called when the group chat receives a GroupChatStart or GroupChatAgentResponse event,
before calling the select_speakers method.
"""
self._message_thread.extend(messages)
@abstractmethod
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
"""Select a speaker from the participants and return the
topic type of the selected speaker."""
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
"""Select speakers from the participants and return the topic types of the selected speaker.
This is called when the group chat manager have received all responses from the participants
for a turn and is ready to select the next speakers for the next turn.
Args:
thread: The message thread of the group chat.
Returns:
A list of topic types of the selected speakers.
If only one speaker is selected, a single string is returned instead of a list.
"""
...
@abstractmethod

View File

@@ -89,7 +89,7 @@ class ChatAgentContainer(SequentialRoutedAgent):
# Publish the response to the group chat.
self._message_buffer.clear()
await self.publish_message(
GroupChatAgentResponse(agent_response=response),
GroupChatAgentResponse(agent_response=response, agent_name=self._agent.name),
topic_id=DefaultTopicId(type=self._parent_topic_type),
cancellation_token=ctx.cancellation_token,
)

View File

@@ -48,6 +48,9 @@ class GroupChatAgentResponse(BaseModel):
agent_response: Response
"""The response from an agent."""
agent_name: str
"""The name of the agent that produced the response."""
class GroupChatRequestPublish(BaseModel):
"""A request to publish a message to a group chat."""

View File

@@ -1,5 +1,6 @@
import asyncio
from typing import Any, Callable, Dict, List, Literal, Mapping, Sequence, Set
from collections import Counter, deque
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set
from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel
from pydantic import BaseModel
@@ -208,154 +209,70 @@ class GraphFlowManager(BaseGroupChatManager):
max_turns=max_turns,
message_factory=message_factory,
)
self._graph = graph
self._graph.graph_validate()
self._graph_has_cycles = self._graph.get_has_cycles()
if self._graph_has_cycles and self._termination_condition is None and self._max_turns is None:
graph.graph_validate()
if graph.get_has_cycles() and self._termination_condition is None and self._max_turns is None:
raise ValueError("A termination condition is required for cyclic graphs without a maximum turn limit.")
self._graph = graph
# Lookup table for incoming edges for each node.
self._parents = graph.get_parents()
# Lookup table for outgoing edges for each node.
self._edges: Dict[str, List[DiGraphEdge]] = {n: node.edges for n, node in graph.nodes.items()}
# Activation lookup table for each node.
self._activation: Dict[str, Literal["any", "all"]] = {n: node.activation for n, node in graph.nodes.items()}
self._use_default_start = self._graph.default_start_node is not None
self._default_start_executed = False
self._start_nodes = graph.get_start_nodes()
self._leaf_nodes = graph.get_leaf_nodes()
self._parents = graph.get_parents() # Parent node dependencies - helper dict to get all incoming edges
self._active_nodes: Set[str] = set() # Currently executing nodes
self._active_node_count: Dict[str, int] = {
node: 0 for node in graph.nodes
} # Number of times a node has been active
# === Mutable states for the graph execution ===
# Count the number of remaining parents to activate each node.
self._remaining: Counter[str] = Counter({n: len(p) for n, p in self._parents.items()})
# Lookup table for nodes that have been enqueued through an any activation.
# This is used to prevent re-adding the same node multiple times.
self._enqueued_any: Dict[str, bool] = {n: False for n in graph.nodes}
# Ready queue for nodes that are ready to execute, starting with the start nodes.
self._ready: Deque[str] = deque([n for n in graph.get_start_nodes()])
# These are nodes next in line for execution as one or more of their parent nodes have started execution.
# They execute when all their parent nodes have executed.
# Nodes are added to this dict when at least one of their parent nodes becomes active.
# Start nodes (no parents) are added to this dict at initialization as they are always ready to run.
self._pending_execution: Dict[str, List[str]] = {node: [] for node in graph.get_start_nodes()}
async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
await super().update_message_thread(messages)
def _get_valid_target(self, node: DiGraphNode, content: str) -> str:
"""Check if a condition is met in the chat history."""
for edge in node.edges:
if edge.condition and edge.condition in content:
return edge.target
# Find the node that ran in the current turn.
message = messages[-1]
if message.source not in self._graph.nodes:
# Ignore messages from sources outside of the graph.
return
assert isinstance(message, BaseChatMessage)
source = message.source
content = message.to_model_text()
raise RuntimeError(f"Condition not met for node {node.name}. Content: {content}")
def _is_node_ready(self, node_name: str) -> bool:
"""Check if a node is ready to execute based on its parent nodes.
If activation is any then execute as soon as any parent has finished
If activation is all then execute only when all parents have finished
"""
node = self._graph.nodes[node_name]
if node.activation == "any":
return bool(self._pending_execution[node_name])
return all(parent in self._pending_execution[node_name] for parent in self._parents[node_name])
async def _select_speakers(self, thread: List[BaseAgentEvent | BaseChatMessage], many: bool = True) -> List[str]:
"""Select the next set of agents to execute based on DAG constraints."""
next_speakers: Set[str] = set()
source_node: DiGraphNode | None = None
source: str | None = None
if thread and isinstance(thread[-1], BaseChatMessage):
source = thread[-1].source # name of the agent that just finished
content = thread[-1].to_model_text()
# Safety check: only an active node can send a response
if source != "user":
if source not in self._active_nodes:
raise RuntimeError(f"Agent '{source}' is not currently active.")
# Mark the node as no longer active (it just finished)
self._active_node_count[source] -= 1
if self._active_node_count[source] <= 0:
self._active_nodes.remove(source)
source_node = self._graph.nodes[source]
if source_node.edges:
# Case: conditional edges — only execute if condition is met
target_nodes_names: List[str] = []
if source_node.edges[0].condition is not None:
target_nodes_names = [self._get_valid_target(source_node, content)]
other_nodes = [
edge.target for edge in source_node.edges if edge.target != target_nodes_names[0]
]
for other_node in other_nodes:
other_active_parents = [
parent
for parent in self._parents[other_node]
if (parent != source and parent in self._active_nodes)
]
if not other_active_parents:
self._pending_execution.pop(other_node)
else:
self._pending_execution[other_node] = other_active_parents
else:
# Case: unconditional edges — mark this source as completed for all its children
target_nodes_names = [edge.target for edge in source_node.edges]
for target in target_nodes_names:
self._pending_execution[target].append(source)
# Propagate the update to the children of the node.
for edge in self._edges[source]:
if edge.condition and edge.condition not in content:
continue
if self._activation[edge.target] == "all":
self._remaining[edge.target] -= 1
if self._remaining[edge.target] == 0:
# If all parents are done, add to the ready queue.
self._ready.append(edge.target)
else:
# TODO: Check if there are any usecase where the User can decide on the next speaker
pass
# If activation is any, add to the ready queue if not already enqueued.
if not self._enqueued_any[edge.target]:
self._ready.append(edge.target)
self._enqueued_any[edge.target] = True
# After updating _pending_execution, check which nodes are now unblocked
for node_name in list(self._pending_execution):
if self._use_default_start and not self._default_start_executed:
if node_name == self._graph.default_start_node:
next_speakers.add(node_name)
self._default_start_executed = True
break
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]:
# Drain the ready queue for the next set of speakers.
speakers: List[str] = []
while self._ready:
speaker = self._ready.popleft()
speakers.append(speaker)
# Reset the bookkeeping for the node that were selected.
if self._activation[speaker] == "any":
self._enqueued_any[speaker] = False
else:
self._remaining[speaker] = len(self._parents[speaker])
if self._is_node_ready(node_name):
next_speakers.add(node_name)
node = self._graph.nodes[node_name]
if node.activation == "all":
self._pending_execution.pop(node_name)
else:
# If activation is any, remove the parent that just finished
if source is not None:
self._pending_execution[node_name] = [
parent for parent in self._pending_execution[node_name] if parent != source
]
# If none of the other parents of this node are active, remove this node from pending execution
node_parents = self._parents[node_name]
if not any(parent in self._active_nodes for parent in node_parents):
self._pending_execution.pop(node_name)
if not many:
break
# Prepopulate children of next_speakers into _pending_execution
for node_name in next_speakers:
for edge in self._graph.nodes[node_name].edges:
if edge.target not in self._pending_execution:
self._pending_execution[edge.target] = []
# Mark newly selected speakers as active
for speaker in next_speakers:
if speaker not in self._active_nodes:
self._active_nodes.add(speaker)
self._active_node_count[speaker] += 1
if not self._pending_execution and not next_speakers and not self._active_nodes:
next_speakers = set([_DIGRAPH_STOP_AGENT_NAME]) # Call the termination agent
return list(next_speakers)
async def select_speakers(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> List[str]:
return await self._select_speakers(thread)
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
"""Select a speaker from the participants and return the
topic type of the selected speaker."""
speakers = await self._select_speakers(thread, many=False)
# If there are no speakers, trigger the stop agent.
if not speakers:
raise RuntimeError("No available speakers found.")
return speakers[0]
speakers = [_DIGRAPH_STOP_AGENT_NAME]
return speakers
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
pass
@@ -365,10 +282,9 @@ class GraphFlowManager(BaseGroupChatManager):
state = {
"message_thread": [message.dump() for message in self._message_thread],
"current_turn": self._current_turn,
"active_nodes": list(self._active_nodes),
"pending_execution": self._pending_execution,
"active_node_count": self._active_node_count,
"default_start_executed": self._default_start_executed,
"remaining": dict(self._remaining),
"enqueued_any": dict(self._enqueued_any),
"ready": list(self._ready),
}
return state
@@ -376,10 +292,9 @@ class GraphFlowManager(BaseGroupChatManager):
"""Restore execution state from saved data."""
self._message_thread = [self._message_factory.create(msg) for msg in state["message_thread"]]
self._current_turn = state["current_turn"]
self._active_nodes = set(state["active_nodes"])
self._pending_execution = state["pending_execution"]
self._active_node_count = state["active_node_count"]
self._default_start_executed = state.get("default_start_executed", False)
self._remaining = Counter(state["remaining"])
self._enqueued_any = state["enqueued_any"]
self._ready = deque(state["ready"])
async def reset(self) -> None:
"""Reset execution state to the start of the graph."""
@@ -387,11 +302,9 @@ class GraphFlowManager(BaseGroupChatManager):
self._message_thread.clear()
if self._termination_condition:
await self._termination_condition.reset()
self._active_nodes = set()
self._active_node_count = {node: 0 for node in self._graph.nodes}
self._pending_execution = {node: [] for node in self._start_nodes}
self._default_start_executed = False
self._remaining = Counter({n: len(p) for n, p in self._parents.items()})
self._enqueued_any = {n: False for n in self._graph.nodes}
self._ready = deque([n for n in self._graph.get_start_nodes()])
class _StopAgent(BaseChatAgent):

View File

@@ -2,7 +2,7 @@ import asyncio
import json
import logging
import re
from typing import Any, Dict, List, Mapping
from typing import Any, Dict, List, Mapping, Sequence
from autogen_core import AgentId, CancellationToken, DefaultTopicId, MessageContext, event, rpc
from autogen_core.models import (
@@ -37,6 +37,7 @@ from .._events import (
GroupChatReset,
GroupChatStart,
GroupChatTermination,
SerializableException,
)
from ._prompts import (
ORCHESTRATOR_FINAL_ANSWER_PROMPT,
@@ -187,22 +188,29 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: # type: ignore
delta: List[BaseAgentEvent | BaseChatMessage] = []
if message.agent_response.inner_messages is not None:
for inner_message in message.agent_response.inner_messages:
delta.append(inner_message)
await self.update_message_thread([message.agent_response.chat_message])
delta.append(message.agent_response.chat_message)
try:
delta: List[BaseAgentEvent | BaseChatMessage] = []
if message.agent_response.inner_messages is not None:
for inner_message in message.agent_response.inner_messages:
delta.append(inner_message)
await self.update_message_thread([message.agent_response.chat_message])
delta.append(message.agent_response.chat_message)
if self._termination_condition is not None:
stop_message = await self._termination_condition(delta)
if stop_message is not None:
# Reset the termination conditions.
await self._termination_condition.reset()
# Signal termination.
await self._signal_termination(stop_message)
return
await self._orchestrate_step(ctx.cancellation_token)
if self._termination_condition is not None:
stop_message = await self._termination_condition(delta)
if stop_message is not None:
# Reset the termination conditions.
await self._termination_condition.reset()
# Signal termination.
await self._signal_termination(stop_message)
return
await self._orchestrate_step(ctx.cancellation_token)
except Exception as e:
error = SerializableException.from_exception(e)
await self._signal_termination_with_error(error)
# Raise the error to the runtime.
raise
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
pass
@@ -229,9 +237,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
self._n_rounds = orchestrator_state.n_rounds
self._n_stalls = orchestrator_state.n_stalls
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
"""Not used in this orchestrator, we select next speaker in _orchestrate_step."""
return ""
return [""]
async def reset(self) -> None:
"""Reset the group chat manager."""
@@ -275,7 +283,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
# Broadcast
await self.publish_message(
GroupChatAgentResponse(agent_response=Response(chat_message=ledger_message)),
GroupChatAgentResponse(agent_response=Response(chat_message=ledger_message), agent_name=self._name),
topic_id=DefaultTopicId(type=self._group_topic_type),
)
@@ -389,7 +397,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
# Broadcast it
await self.publish_message( # Broadcast
GroupChatAgentResponse(agent_response=Response(chat_message=message)),
GroupChatAgentResponse(agent_response=Response(chat_message=message), agent_name=self._name),
topic_id=DefaultTopicId(type=self._group_topic_type),
cancellation_token=cancellation_token,
)
@@ -470,7 +478,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
# Broadcast
await self.publish_message(
GroupChatAgentResponse(agent_response=Response(chat_message=message)),
GroupChatAgentResponse(agent_response=Response(chat_message=message), agent_name=self._name),
topic_id=DefaultTopicId(type=self._group_topic_type),
cancellation_token=cancellation_token,
)

View File

@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Callable, List, Mapping
from typing import Any, Callable, List, Mapping, Sequence
from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel
@@ -69,8 +69,13 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
self._current_turn = round_robin_state.current_turn
self._next_speaker_index = round_robin_state.next_speaker_index
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
"""Select a speaker from the participants in a round-robin fashion."""
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
"""Select a speaker from the participants in a round-robin fashion.
.. note::
This method always returns a single speaker.
"""
current_speaker_index = self._next_speaker_index
self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_names)
current_speaker = self._participant_names[current_speaker_index]

View File

@@ -150,10 +150,14 @@ class SelectorGroupChatManager(BaseGroupChatManager):
base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)]
await self._add_messages_to_context(self._model_context, base_chat_messages)
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
"""Selects the next speaker in a group chat using a ChatCompletion client,
with the selector function as override if it returns a speaker name.
.. note::
This method always returns a single speaker name.
A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
"""
# Use the selector function if provided.
@@ -171,7 +175,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
f"Expected one of: {self._participant_names}."
)
# Skip the model based selection.
return speaker
return [speaker]
# Use the candidate function to filter participants if provided
if self._candidate_func is not None:
@@ -211,7 +215,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
agent_name = participants[0]
self._previous_speaker = agent_name
trace_logger.debug(f"Selected speaker: {agent_name}")
return agent_name
return [agent_name]
def construct_message_history(self, message_history: List[LLMMessage]) -> str:
# Construct the history of the conversation.

View File

@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Callable, List, Mapping
from typing import Any, Callable, List, Mapping, Sequence
from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel
@@ -79,17 +79,22 @@ class SwarmGroupChatManager(BaseGroupChatManager):
await self._termination_condition.reset()
self._current_speaker = self._participant_names[0]
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
"""Select a speaker from the participants based on handoff message.
Looks for the last handoff message in the thread to determine the next speaker."""
Looks for the last handoff message in the thread to determine the next speaker.
.. note::
This method always returns a single speaker.
"""
if len(thread) == 0:
return self._current_speaker
return [self._current_speaker]
for message in reversed(thread):
if isinstance(message, HandoffMessage):
self._current_speaker = message.target
# The latest handoff message should always target a valid participant.
assert self._current_speaker in self._participant_names
return self._current_speaker
return [self._current_speaker]
return self._current_speaker
async def save_state(self) -> Mapping[str, Any]:

View File

@@ -1,6 +1,6 @@
import asyncio
from typing import Any, AsyncGenerator, Callable, Dict, List, Sequence, Set
from unittest.mock import AsyncMock, patch
from typing import AsyncGenerator, List, Sequence
from unittest.mock import patch
import pytest
import pytest_asyncio
@@ -12,9 +12,8 @@ from autogen_agentchat.agents import (
PerSourceFilter,
)
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.conditions import MaxMessageTermination, SourceMatchTermination
from autogen_agentchat.messages import BaseChatMessage, ChatMessage, MessageFactory, StopMessage, TextMessage
from autogen_agentchat.messages import BaseTextChatMessage as TextChatMessage
from autogen_agentchat.teams import (
DiGraphBuilder,
GraphFlow,
@@ -269,61 +268,6 @@ def test_validate_graph_mixed_conditions() -> None:
graph.graph_validate()
def test_get_valid_target() -> None:
node = DiGraphNode(
name="A",
edges=[DiGraphEdge(target="B", condition="approve"), DiGraphEdge(target="C", condition="reject")],
)
manager = GraphFlowManager.__new__(GraphFlowManager)
assert manager._get_valid_target(node, "please approve this") == "B" # pyright: ignore[reportPrivateUsage]
assert manager._get_valid_target(node, "i reject this") == "C" # pyright: ignore[reportPrivateUsage]
with pytest.raises(RuntimeError):
manager._get_valid_target(node, "unknown path") # pyright: ignore[reportPrivateUsage]
def test_is_node_ready_all_and_any() -> None:
graph = DiGraph(
nodes={
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="C")]),
"B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
"C": DiGraphNode(name="C", edges=[], activation="all"),
}
)
manager = GraphFlowManager.__new__(GraphFlowManager)
manager._graph = graph # pyright: ignore[reportPrivateUsage]
manager._parents = graph.get_parents() # pyright: ignore[reportPrivateUsage]
# === Test "all" activation ===
# Case 1: No parent finished
manager._pending_execution = {"C": []} # pyright: ignore[reportPrivateUsage]
assert not manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage]
# Case 2: One parent finished
manager._pending_execution = {"C": ["A"]} # pyright: ignore[reportPrivateUsage]
assert not manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage]
# Case 3: All parents finished
manager._pending_execution = {"C": ["A", "B"]} # pyright: ignore[reportPrivateUsage]
assert manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage]
# === Test "any" activation ===
graph.nodes["C"].activation = "any"
# Case 1: No parent finished
manager._pending_execution = {"C": []} # pyright: ignore[reportPrivateUsage]
assert not manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage]
# Case 2: One parent finished
manager._pending_execution = {"C": ["B"]} # pyright: ignore[reportPrivateUsage]
assert manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage]
# Case 3: All parents finished
manager._pending_execution = {"C": ["A", "B"]} # pyright: ignore[reportPrivateUsage]
assert manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage]
@pytest.mark.asyncio
async def test_invalid_digraph_manager_cycle_without_termination() -> None:
"""Test GraphManager raises error for cyclic graph without termination condition."""
@@ -359,170 +303,6 @@ async def test_invalid_digraph_manager_cycle_without_termination() -> None:
)
@pytest.fixture
def digraph_manager() -> Callable[..., GraphFlowManager]:
@patch(
"autogen_agentchat.teams._group_chat._base_group_chat_manager.BaseGroupChatManager.__init__", return_value=None
)
def _create(
_: Any,
graph: DiGraph,
active_nodes: Set[str] | None = None,
thread: List[BaseAgentEvent | BaseChatMessage] | None = None,
pending: Dict[str, List[str]] | None = None,
) -> GraphFlowManager:
manager = GraphFlowManager.__new__(GraphFlowManager)
manager._graph = graph # pyright: ignore[reportPrivateUsage]
manager._parents = graph.get_parents() # pyright: ignore[reportPrivateUsage]
manager._start_nodes = graph.get_start_nodes() # pyright: ignore[reportPrivateUsage]
manager._leaf_nodes = graph.get_leaf_nodes() # pyright: ignore[reportPrivateUsage]
manager._active_nodes = set(active_nodes or []) # pyright: ignore[reportPrivateUsage]
manager._active_node_count = {node: 0 for node in graph.nodes} # pyright: ignore[reportPrivateUsage]
manager._message_factory = MessageFactory() # pyright: ignore[reportPrivateUsage]
manager._message_thread = thread if thread is not None else [] # pyright: ignore[reportPrivateUsage]
manager._pending_execution = pending if pending is not None else {node: [] for node in graph.get_start_nodes()} # pyright: ignore[reportPrivateUsage]
manager._name = "test_manager" # pyright: ignore[reportPrivateUsage]
manager._use_default_start = False # pyright: ignore[reportPrivateUsage]
return manager
return _create
# -------------------- Test: Sequential Flow --------------------
@pytest.mark.asyncio
async def test_select_speakers_linear(digraph_manager: Callable[..., GraphFlowManager]) -> None:
graph = DiGraph(
nodes={
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
"B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
"C": DiGraphNode(name="C", edges=[]),
}
)
message_thread = [TextChatMessage(source="A", content="done", metadata={})]
manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []})
result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage]
assert result == ["B"]
assert "B" in manager._active_nodes # pyright: ignore[reportPrivateUsage]
# -------------------- Test: Parallel Fan-out --------------------
@pytest.mark.asyncio
async def test_select_speakers_parallel(digraph_manager: Callable[..., GraphFlowManager]) -> None:
graph = DiGraph(
nodes={
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
"B": DiGraphNode(name="B", edges=[]),
"C": DiGraphNode(name="C", edges=[]),
}
)
message_thread = [TextChatMessage(source="A", content="done", metadata={})]
manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []})
result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage]
assert set(result) == {"B", "C"}
assert "B" in manager._active_nodes # pyright: ignore[reportPrivateUsage]
assert "C" in manager._active_nodes # pyright: ignore[reportPrivateUsage]
# -------------------- Test: Conditional Path --------------------
@pytest.mark.asyncio
async def test_select_speakers_conditional(digraph_manager: Callable[..., GraphFlowManager]) -> None:
graph = DiGraph(
nodes={
"A": DiGraphNode(
name="A", edges=[DiGraphEdge(target="B", condition="yes"), DiGraphEdge(target="C", condition="no")]
),
"B": DiGraphNode(name="B", edges=[]),
"C": DiGraphNode(name="C", edges=[]),
}
)
message_thread = [TextChatMessage(source="A", content="no", metadata={})]
manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []})
result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage]
assert result == ["C"]
assert "C" in manager._active_nodes # pyright: ignore[reportPrivateUsage]
@pytest.mark.asyncio
async def test_select_speakers_from_start_nodes(digraph_manager: Callable[..., GraphFlowManager]) -> None:
graph = DiGraph(
nodes={
"A": DiGraphNode(name="A", edges=[]),
"B": DiGraphNode(name="B", edges=[]),
}
)
# No prior message — both are start nodes
manager = digraph_manager(graph=graph, active_nodes=set(), thread=[], pending={"A": [], "B": []})
result = await manager.select_speakers([])
assert set(result) == {"A", "B"}
@pytest.mark.asyncio
async def test_select_speakers_termination(digraph_manager: Callable[..., GraphFlowManager]) -> None:
graph = DiGraph(
nodes={
"A": DiGraphNode(name="A", edges=[]),
}
)
# Create the manager and manually patch _signal_termination to track calls
manager = digraph_manager(
graph=graph, active_nodes={"A"}, thread=[TextChatMessage(source="A", content="done", metadata={})], pending={}
)
manager._signal_termination = AsyncMock() # type: ignore[assignment]
result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage]
# No speakers left to run, so result should be empty
assert result == [_DIGRAPH_STOP_AGENT_NAME]
@pytest.mark.asyncio
async def test_select_speakers_conditional_all_activation(
digraph_manager: Callable[..., GraphFlowManager],
) -> None:
graph = DiGraph(
nodes={
"A": DiGraphNode(
name="A", edges=[DiGraphEdge(target="B", condition="yes"), DiGraphEdge(target="C", condition="no")]
),
"B": DiGraphNode(name="B", edges=[], activation="all"),
"C": DiGraphNode(name="C", edges=[], activation="all"),
}
)
message_thread = [TextChatMessage(source="A", content="no", metadata={})]
manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []})
result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage]
assert result == ["C"]
assert "C" in manager._active_nodes # pyright: ignore[reportPrivateUsage]
@pytest.mark.asyncio
async def test_select_speakers_conditional_any_activation(
digraph_manager: Callable[..., GraphFlowManager],
) -> None:
graph = DiGraph(
nodes={
"A": DiGraphNode(
name="A", edges=[DiGraphEdge(target="B", condition="yes"), DiGraphEdge(target="C", condition="no")]
),
"B": DiGraphNode(name="B", edges=[], activation="any"),
"C": DiGraphNode(name="C", edges=[], activation="any"),
}
)
message_thread = [TextChatMessage(source="A", content="yes", metadata={})]
manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []})
result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage]
assert result == ["B"]
assert "B" in manager._active_nodes # pyright: ignore[reportPrivateUsage]
class _EchoAgent(BaseChatAgent):
def __init__(self, name: str, description: str) -> None:
super().__init__(name, description)
@@ -1417,7 +1197,6 @@ async def test_graph_flow_serialize_deserialize() -> None:
participants=builder.get_participants(),
graph=builder.build(),
runtime=None,
termination_condition=MaxMessageTermination(5),
)
serialized = team.dump_component()
@@ -1439,11 +1218,52 @@ async def test_graph_flow_serialize_deserialize() -> None:
assert results.messages[1].source == "A"
assert results.messages[1].content == "0"
assert isinstance(results.messages[2], TextMessage)
assert results.messages[2].source == "A"
assert results.messages[2].content == "1"
assert isinstance(results.messages[3], TextMessage)
assert results.messages[3].source == "B"
assert results.messages[3].content == "0"
assert results.messages[2].source == "B"
assert results.messages[2].content == "0"
assert isinstance(results.messages[-1], StopMessage)
assert results.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert results.messages[-1].content == "Digraph execution is complete"
@pytest.mark.asyncio
async def test_graph_flow_stateful_pause_and_resume_with_termination() -> None:
client_a = ReplayChatCompletionClient(["A1", "A2"])
client_b = ReplayChatCompletionClient(["B1"])
a = AssistantAgent("A", model_client=client_a)
b = AssistantAgent("B", model_client=client_b)
builder = DiGraphBuilder()
builder.add_node(a).add_node(b)
builder.add_edge(a, b)
builder.set_entry_point(a)
team = GraphFlow(
participants=builder.get_participants(),
graph=builder.build(),
runtime=None,
termination_condition=SourceMatchTermination(sources=["A"]),
)
result = await team.run(task="Start")
assert len(result.messages) == 2
assert result.messages[0].source == "user"
assert result.messages[1].source == "A"
assert result.stop_reason is not None and result.stop_reason == "'A' answered"
# Export state.
state = await team.save_state()
# Load state into a new team.
new_team = GraphFlow(
participants=builder.get_participants(),
graph=builder.build(),
runtime=None,
)
await new_team.load_state(state)
# Resume.
result = await new_team.run()
assert len(result.messages) == 2
assert result.messages[0].source == "B"
assert result.messages[1].source == _DIGRAPH_STOP_AGENT_NAME