mirror of
https://github.com/microsoft/autogen.git
synced 2026-05-13 03:00:55 -04:00
Enable concurrent execution of agents in GraphFlow (#6545)
Support concurrent execution in `GraphFlow`:
- Updated `BaseGroupChatManager.select_speaker` to return a union of a
single string or a list of speaker name strings and added logics to
check for currently activated speakers and only proceed to select next
speakers when all activated speakers have finished.
- Updated existing teams (e.g., `SelectorGroupChat`) with the new
signature, while still returning a single speaker in their
implementations.
- Updated `GraphFlow` to support multiple speakers selected.
- Refactored `GraphFlow` for less dictionary gymnastic by using a queue
and update using `update_message_thread`.
Example: a fan out graph:
```python
import asyncio
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
from autogen_ext.models.openai import OpenAIChatCompletionClient
async def main():
# Initialize agents with OpenAI model clients.
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.")
agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.")
agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Japanese.")
# Create a directed graph with fan-out flow A -> (B, C).
builder = DiGraphBuilder()
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
graph = builder.build()
# Create a GraphFlow team with the directed graph.
team = GraphFlow(
participants=[agent_a, agent_b, agent_c],
graph=graph,
)
# Run the team and print the events.
async for event in team.run_stream(task="Write a short story about a cat."):
print(event)
asyncio.run(main())
```
Resolves:
#6541
#6533
This commit is contained in:
@@ -2,7 +2,7 @@ import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Sequence
|
||||
|
||||
from autogen_core import DefaultTopicId, MessageContext, event, rpc
|
||||
from autogen_core import CancellationToken, DefaultTopicId, MessageContext, event, rpc
|
||||
|
||||
from ...base import TerminationCondition
|
||||
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, SelectSpeakerEvent, StopMessage
|
||||
@@ -79,6 +79,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
self._current_turn = 0
|
||||
self._message_factory = message_factory
|
||||
self._emit_team_events = emit_team_events
|
||||
self._active_speakers: List[str] = []
|
||||
|
||||
@rpc
|
||||
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
|
||||
@@ -122,58 +123,35 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Select a speaker to start/continue the conversation
|
||||
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
|
||||
# Link the select speaker future to the cancellation token.
|
||||
ctx.cancellation_token.link_future(speaker_name_future)
|
||||
speaker_name = await speaker_name_future
|
||||
if speaker_name not in self._participant_name_to_topic_type:
|
||||
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
|
||||
await self._log_speaker_selection(speaker_name)
|
||||
|
||||
# Send the message to the next speaker
|
||||
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
|
||||
await self.publish_message(
|
||||
GroupChatRequestPublish(),
|
||||
topic_id=DefaultTopicId(type=speaker_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
|
||||
async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
|
||||
self._message_thread.extend(messages)
|
||||
# Select speakers to start/continue the conversation
|
||||
await self._transition_to_next_speakers(ctx.cancellation_token)
|
||||
|
||||
@event
|
||||
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
|
||||
try:
|
||||
# Append the message to the message thread and construct the delta.
|
||||
# Construct the detla from the agent response.
|
||||
delta: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if message.agent_response.inner_messages is not None:
|
||||
for inner_message in message.agent_response.inner_messages:
|
||||
delta.append(inner_message)
|
||||
delta.append(message.agent_response.chat_message)
|
||||
|
||||
# Append the messages to the message thread.
|
||||
await self.update_message_thread(delta)
|
||||
|
||||
# Remove the agent from the active speakers list.
|
||||
self._active_speakers.remove(message.agent_name)
|
||||
if len(self._active_speakers) > 0:
|
||||
# If there are still active speakers, return without doing anything.
|
||||
return
|
||||
|
||||
# Check if the conversation should be terminated.
|
||||
if await self._apply_termination_condition(delta, increment_turn_count=True):
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Select a speaker to continue the conversation.
|
||||
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
|
||||
# Link the select speaker future to the cancellation token.
|
||||
ctx.cancellation_token.link_future(speaker_name_future)
|
||||
speaker_name = await speaker_name_future
|
||||
if speaker_name not in self._participant_name_to_topic_type:
|
||||
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
|
||||
await self._log_speaker_selection(speaker_name)
|
||||
|
||||
# Send the message to the next speakers
|
||||
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
|
||||
await self.publish_message(
|
||||
GroupChatRequestPublish(),
|
||||
topic_id=DefaultTopicId(type=speaker_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
# Select speakers to continue the conversation.
|
||||
await self._transition_to_next_speakers(ctx.cancellation_token)
|
||||
except Exception as e:
|
||||
# Handle the exception and signal termination with an error.
|
||||
error = SerializableException.from_exception(e)
|
||||
@@ -181,6 +159,29 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
# Raise the exception to the runtime.
|
||||
raise
|
||||
|
||||
async def _transition_to_next_speakers(self, cancellation_token: CancellationToken) -> None:
|
||||
speaker_names_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
|
||||
# Link the select speaker future to the cancellation token.
|
||||
cancellation_token.link_future(speaker_names_future)
|
||||
speaker_names = await speaker_names_future
|
||||
if isinstance(speaker_names, str):
|
||||
# If only one speaker is selected, convert it to a list.
|
||||
speaker_names = [speaker_names]
|
||||
for speaker_name in speaker_names:
|
||||
if speaker_name not in self._participant_name_to_topic_type:
|
||||
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
|
||||
await self._log_speaker_selection(speaker_names)
|
||||
|
||||
# Send request to publish message to the next speakers
|
||||
for speaker_name in speaker_names:
|
||||
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
|
||||
await self.publish_message(
|
||||
GroupChatRequestPublish(),
|
||||
topic_id=DefaultTopicId(type=speaker_topic_type),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
self._active_speakers.append(speaker_name)
|
||||
|
||||
async def _apply_termination_condition(
|
||||
self, delta: Sequence[BaseAgentEvent | BaseChatMessage], increment_turn_count: bool = False
|
||||
) -> bool:
|
||||
@@ -216,9 +217,9 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _log_speaker_selection(self, speaker_name: str) -> None:
|
||||
async def _log_speaker_selection(self, speaker_names: List[str]) -> None:
|
||||
"""Log the selected speaker to the output message queue."""
|
||||
select_msg = SelectSpeakerEvent(content=[speaker_name], source=self._name)
|
||||
select_msg = SelectSpeakerEvent(content=speaker_names, source=self._name)
|
||||
if self._emit_team_events:
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=select_msg),
|
||||
@@ -284,10 +285,26 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
|
||||
"""Update the message thread with the new messages.
|
||||
This is called when the group chat receives a GroupChatStart or GroupChatAgentResponse event,
|
||||
before calling the select_speakers method.
|
||||
"""
|
||||
self._message_thread.extend(messages)
|
||||
|
||||
@abstractmethod
|
||||
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
|
||||
"""Select a speaker from the participants and return the
|
||||
topic type of the selected speaker."""
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Select speakers from the participants and return the topic types of the selected speaker.
|
||||
This is called when the group chat manager have received all responses from the participants
|
||||
for a turn and is ready to select the next speakers for the next turn.
|
||||
|
||||
Args:
|
||||
thread: The message thread of the group chat.
|
||||
|
||||
Returns:
|
||||
A list of topic types of the selected speakers.
|
||||
If only one speaker is selected, a single string is returned instead of a list.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -89,7 +89,7 @@ class ChatAgentContainer(SequentialRoutedAgent):
|
||||
# Publish the response to the group chat.
|
||||
self._message_buffer.clear()
|
||||
await self.publish_message(
|
||||
GroupChatAgentResponse(agent_response=response),
|
||||
GroupChatAgentResponse(agent_response=response, agent_name=self._agent.name),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
|
||||
@@ -48,6 +48,9 @@ class GroupChatAgentResponse(BaseModel):
|
||||
agent_response: Response
|
||||
"""The response from an agent."""
|
||||
|
||||
agent_name: str
|
||||
"""The name of the agent that produced the response."""
|
||||
|
||||
|
||||
class GroupChatRequestPublish(BaseModel):
|
||||
"""A request to publish a message to a group chat."""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, Dict, List, Literal, Mapping, Sequence, Set
|
||||
from collections import Counter, deque
|
||||
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set
|
||||
|
||||
from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
@@ -208,154 +209,70 @@ class GraphFlowManager(BaseGroupChatManager):
|
||||
max_turns=max_turns,
|
||||
message_factory=message_factory,
|
||||
)
|
||||
self._graph = graph
|
||||
self._graph.graph_validate()
|
||||
self._graph_has_cycles = self._graph.get_has_cycles()
|
||||
if self._graph_has_cycles and self._termination_condition is None and self._max_turns is None:
|
||||
graph.graph_validate()
|
||||
if graph.get_has_cycles() and self._termination_condition is None and self._max_turns is None:
|
||||
raise ValueError("A termination condition is required for cyclic graphs without a maximum turn limit.")
|
||||
self._graph = graph
|
||||
# Lookup table for incoming edges for each node.
|
||||
self._parents = graph.get_parents()
|
||||
# Lookup table for outgoing edges for each node.
|
||||
self._edges: Dict[str, List[DiGraphEdge]] = {n: node.edges for n, node in graph.nodes.items()}
|
||||
# Activation lookup table for each node.
|
||||
self._activation: Dict[str, Literal["any", "all"]] = {n: node.activation for n, node in graph.nodes.items()}
|
||||
|
||||
self._use_default_start = self._graph.default_start_node is not None
|
||||
self._default_start_executed = False
|
||||
self._start_nodes = graph.get_start_nodes()
|
||||
self._leaf_nodes = graph.get_leaf_nodes()
|
||||
self._parents = graph.get_parents() # Parent node dependencies - helper dict to get all incoming edges
|
||||
self._active_nodes: Set[str] = set() # Currently executing nodes
|
||||
self._active_node_count: Dict[str, int] = {
|
||||
node: 0 for node in graph.nodes
|
||||
} # Number of times a node has been active
|
||||
# === Mutable states for the graph execution ===
|
||||
# Count the number of remaining parents to activate each node.
|
||||
self._remaining: Counter[str] = Counter({n: len(p) for n, p in self._parents.items()})
|
||||
# Lookup table for nodes that have been enqueued through an any activation.
|
||||
# This is used to prevent re-adding the same node multiple times.
|
||||
self._enqueued_any: Dict[str, bool] = {n: False for n in graph.nodes}
|
||||
# Ready queue for nodes that are ready to execute, starting with the start nodes.
|
||||
self._ready: Deque[str] = deque([n for n in graph.get_start_nodes()])
|
||||
|
||||
# These are nodes next in line for execution as one or more of their parent nodes have started execution.
|
||||
# They execute when all their parent nodes have executed.
|
||||
# Nodes are added to this dict when at least one of their parent nodes becomes active.
|
||||
# Start nodes (no parents) are added to this dict at initialization as they are always ready to run.
|
||||
self._pending_execution: Dict[str, List[str]] = {node: [] for node in graph.get_start_nodes()}
|
||||
async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
|
||||
await super().update_message_thread(messages)
|
||||
|
||||
def _get_valid_target(self, node: DiGraphNode, content: str) -> str:
|
||||
"""Check if a condition is met in the chat history."""
|
||||
for edge in node.edges:
|
||||
if edge.condition and edge.condition in content:
|
||||
return edge.target
|
||||
# Find the node that ran in the current turn.
|
||||
message = messages[-1]
|
||||
if message.source not in self._graph.nodes:
|
||||
# Ignore messages from sources outside of the graph.
|
||||
return
|
||||
assert isinstance(message, BaseChatMessage)
|
||||
source = message.source
|
||||
content = message.to_model_text()
|
||||
|
||||
raise RuntimeError(f"Condition not met for node {node.name}. Content: {content}")
|
||||
|
||||
def _is_node_ready(self, node_name: str) -> bool:
|
||||
"""Check if a node is ready to execute based on its parent nodes.
|
||||
If activation is any then execute as soon as any parent has finished
|
||||
If activation is all then execute only when all parents have finished
|
||||
"""
|
||||
node = self._graph.nodes[node_name]
|
||||
if node.activation == "any":
|
||||
return bool(self._pending_execution[node_name])
|
||||
return all(parent in self._pending_execution[node_name] for parent in self._parents[node_name])
|
||||
|
||||
async def _select_speakers(self, thread: List[BaseAgentEvent | BaseChatMessage], many: bool = True) -> List[str]:
|
||||
"""Select the next set of agents to execute based on DAG constraints."""
|
||||
next_speakers: Set[str] = set()
|
||||
source_node: DiGraphNode | None = None
|
||||
source: str | None = None
|
||||
|
||||
if thread and isinstance(thread[-1], BaseChatMessage):
|
||||
source = thread[-1].source # name of the agent that just finished
|
||||
content = thread[-1].to_model_text()
|
||||
|
||||
# Safety check: only an active node can send a response
|
||||
if source != "user":
|
||||
if source not in self._active_nodes:
|
||||
raise RuntimeError(f"Agent '{source}' is not currently active.")
|
||||
|
||||
# Mark the node as no longer active (it just finished)
|
||||
self._active_node_count[source] -= 1
|
||||
|
||||
if self._active_node_count[source] <= 0:
|
||||
self._active_nodes.remove(source)
|
||||
|
||||
source_node = self._graph.nodes[source]
|
||||
|
||||
if source_node.edges:
|
||||
# Case: conditional edges — only execute if condition is met
|
||||
target_nodes_names: List[str] = []
|
||||
if source_node.edges[0].condition is not None:
|
||||
target_nodes_names = [self._get_valid_target(source_node, content)]
|
||||
other_nodes = [
|
||||
edge.target for edge in source_node.edges if edge.target != target_nodes_names[0]
|
||||
]
|
||||
for other_node in other_nodes:
|
||||
other_active_parents = [
|
||||
parent
|
||||
for parent in self._parents[other_node]
|
||||
if (parent != source and parent in self._active_nodes)
|
||||
]
|
||||
if not other_active_parents:
|
||||
self._pending_execution.pop(other_node)
|
||||
else:
|
||||
self._pending_execution[other_node] = other_active_parents
|
||||
|
||||
else:
|
||||
# Case: unconditional edges — mark this source as completed for all its children
|
||||
target_nodes_names = [edge.target for edge in source_node.edges]
|
||||
|
||||
for target in target_nodes_names:
|
||||
self._pending_execution[target].append(source)
|
||||
# Propagate the update to the children of the node.
|
||||
for edge in self._edges[source]:
|
||||
if edge.condition and edge.condition not in content:
|
||||
continue
|
||||
if self._activation[edge.target] == "all":
|
||||
self._remaining[edge.target] -= 1
|
||||
if self._remaining[edge.target] == 0:
|
||||
# If all parents are done, add to the ready queue.
|
||||
self._ready.append(edge.target)
|
||||
else:
|
||||
# TODO: Check if there are any usecase where the User can decide on the next speaker
|
||||
pass
|
||||
# If activation is any, add to the ready queue if not already enqueued.
|
||||
if not self._enqueued_any[edge.target]:
|
||||
self._ready.append(edge.target)
|
||||
self._enqueued_any[edge.target] = True
|
||||
|
||||
# After updating _pending_execution, check which nodes are now unblocked
|
||||
for node_name in list(self._pending_execution):
|
||||
if self._use_default_start and not self._default_start_executed:
|
||||
if node_name == self._graph.default_start_node:
|
||||
next_speakers.add(node_name)
|
||||
self._default_start_executed = True
|
||||
break
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]:
|
||||
# Drain the ready queue for the next set of speakers.
|
||||
speakers: List[str] = []
|
||||
while self._ready:
|
||||
speaker = self._ready.popleft()
|
||||
speakers.append(speaker)
|
||||
# Reset the bookkeeping for the node that were selected.
|
||||
if self._activation[speaker] == "any":
|
||||
self._enqueued_any[speaker] = False
|
||||
else:
|
||||
self._remaining[speaker] = len(self._parents[speaker])
|
||||
|
||||
if self._is_node_ready(node_name):
|
||||
next_speakers.add(node_name)
|
||||
node = self._graph.nodes[node_name]
|
||||
if node.activation == "all":
|
||||
self._pending_execution.pop(node_name)
|
||||
else:
|
||||
# If activation is any, remove the parent that just finished
|
||||
if source is not None:
|
||||
self._pending_execution[node_name] = [
|
||||
parent for parent in self._pending_execution[node_name] if parent != source
|
||||
]
|
||||
|
||||
# If none of the other parents of this node are active, remove this node from pending execution
|
||||
node_parents = self._parents[node_name]
|
||||
if not any(parent in self._active_nodes for parent in node_parents):
|
||||
self._pending_execution.pop(node_name)
|
||||
|
||||
if not many:
|
||||
break
|
||||
|
||||
# Prepopulate children of next_speakers into _pending_execution
|
||||
for node_name in next_speakers:
|
||||
for edge in self._graph.nodes[node_name].edges:
|
||||
if edge.target not in self._pending_execution:
|
||||
self._pending_execution[edge.target] = []
|
||||
|
||||
# Mark newly selected speakers as active
|
||||
for speaker in next_speakers:
|
||||
if speaker not in self._active_nodes:
|
||||
self._active_nodes.add(speaker)
|
||||
|
||||
self._active_node_count[speaker] += 1
|
||||
|
||||
if not self._pending_execution and not next_speakers and not self._active_nodes:
|
||||
next_speakers = set([_DIGRAPH_STOP_AGENT_NAME]) # Call the termination agent
|
||||
|
||||
return list(next_speakers)
|
||||
|
||||
async def select_speakers(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> List[str]:
|
||||
return await self._select_speakers(thread)
|
||||
|
||||
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
|
||||
"""Select a speaker from the participants and return the
|
||||
topic type of the selected speaker."""
|
||||
speakers = await self._select_speakers(thread, many=False)
|
||||
# If there are no speakers, trigger the stop agent.
|
||||
if not speakers:
|
||||
raise RuntimeError("No available speakers found.")
|
||||
return speakers[0]
|
||||
speakers = [_DIGRAPH_STOP_AGENT_NAME]
|
||||
|
||||
return speakers
|
||||
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
pass
|
||||
@@ -365,10 +282,9 @@ class GraphFlowManager(BaseGroupChatManager):
|
||||
state = {
|
||||
"message_thread": [message.dump() for message in self._message_thread],
|
||||
"current_turn": self._current_turn,
|
||||
"active_nodes": list(self._active_nodes),
|
||||
"pending_execution": self._pending_execution,
|
||||
"active_node_count": self._active_node_count,
|
||||
"default_start_executed": self._default_start_executed,
|
||||
"remaining": dict(self._remaining),
|
||||
"enqueued_any": dict(self._enqueued_any),
|
||||
"ready": list(self._ready),
|
||||
}
|
||||
return state
|
||||
|
||||
@@ -376,10 +292,9 @@ class GraphFlowManager(BaseGroupChatManager):
|
||||
"""Restore execution state from saved data."""
|
||||
self._message_thread = [self._message_factory.create(msg) for msg in state["message_thread"]]
|
||||
self._current_turn = state["current_turn"]
|
||||
self._active_nodes = set(state["active_nodes"])
|
||||
self._pending_execution = state["pending_execution"]
|
||||
self._active_node_count = state["active_node_count"]
|
||||
self._default_start_executed = state.get("default_start_executed", False)
|
||||
self._remaining = Counter(state["remaining"])
|
||||
self._enqueued_any = state["enqueued_any"]
|
||||
self._ready = deque(state["ready"])
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset execution state to the start of the graph."""
|
||||
@@ -387,11 +302,9 @@ class GraphFlowManager(BaseGroupChatManager):
|
||||
self._message_thread.clear()
|
||||
if self._termination_condition:
|
||||
await self._termination_condition.reset()
|
||||
|
||||
self._active_nodes = set()
|
||||
self._active_node_count = {node: 0 for node in self._graph.nodes}
|
||||
self._pending_execution = {node: [] for node in self._start_nodes}
|
||||
self._default_start_executed = False
|
||||
self._remaining = Counter({n: len(p) for n, p in self._parents.items()})
|
||||
self._enqueued_any = {n: False for n in self._graph.nodes}
|
||||
self._ready = deque([n for n in self._graph.get_start_nodes()])
|
||||
|
||||
|
||||
class _StopAgent(BaseChatAgent):
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Mapping
|
||||
from typing import Any, Dict, List, Mapping, Sequence
|
||||
|
||||
from autogen_core import AgentId, CancellationToken, DefaultTopicId, MessageContext, event, rpc
|
||||
from autogen_core.models import (
|
||||
@@ -37,6 +37,7 @@ from .._events import (
|
||||
GroupChatReset,
|
||||
GroupChatStart,
|
||||
GroupChatTermination,
|
||||
SerializableException,
|
||||
)
|
||||
from ._prompts import (
|
||||
ORCHESTRATOR_FINAL_ANSWER_PROMPT,
|
||||
@@ -187,22 +188,29 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
|
||||
@event
|
||||
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: # type: ignore
|
||||
delta: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if message.agent_response.inner_messages is not None:
|
||||
for inner_message in message.agent_response.inner_messages:
|
||||
delta.append(inner_message)
|
||||
await self.update_message_thread([message.agent_response.chat_message])
|
||||
delta.append(message.agent_response.chat_message)
|
||||
try:
|
||||
delta: List[BaseAgentEvent | BaseChatMessage] = []
|
||||
if message.agent_response.inner_messages is not None:
|
||||
for inner_message in message.agent_response.inner_messages:
|
||||
delta.append(inner_message)
|
||||
await self.update_message_thread([message.agent_response.chat_message])
|
||||
delta.append(message.agent_response.chat_message)
|
||||
|
||||
if self._termination_condition is not None:
|
||||
stop_message = await self._termination_condition(delta)
|
||||
if stop_message is not None:
|
||||
# Reset the termination conditions.
|
||||
await self._termination_condition.reset()
|
||||
# Signal termination.
|
||||
await self._signal_termination(stop_message)
|
||||
return
|
||||
await self._orchestrate_step(ctx.cancellation_token)
|
||||
if self._termination_condition is not None:
|
||||
stop_message = await self._termination_condition(delta)
|
||||
if stop_message is not None:
|
||||
# Reset the termination conditions.
|
||||
await self._termination_condition.reset()
|
||||
# Signal termination.
|
||||
await self._signal_termination(stop_message)
|
||||
return
|
||||
|
||||
await self._orchestrate_step(ctx.cancellation_token)
|
||||
except Exception as e:
|
||||
error = SerializableException.from_exception(e)
|
||||
await self._signal_termination_with_error(error)
|
||||
# Raise the error to the runtime.
|
||||
raise
|
||||
|
||||
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
|
||||
pass
|
||||
@@ -229,9 +237,9 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
self._n_rounds = orchestrator_state.n_rounds
|
||||
self._n_stalls = orchestrator_state.n_stalls
|
||||
|
||||
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Not used in this orchestrator, we select next speaker in _orchestrate_step."""
|
||||
return ""
|
||||
return [""]
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset the group chat manager."""
|
||||
@@ -275,7 +283,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
|
||||
# Broadcast
|
||||
await self.publish_message(
|
||||
GroupChatAgentResponse(agent_response=Response(chat_message=ledger_message)),
|
||||
GroupChatAgentResponse(agent_response=Response(chat_message=ledger_message), agent_name=self._name),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
)
|
||||
|
||||
@@ -389,7 +397,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
|
||||
# Broadcast it
|
||||
await self.publish_message( # Broadcast
|
||||
GroupChatAgentResponse(agent_response=Response(chat_message=message)),
|
||||
GroupChatAgentResponse(agent_response=Response(chat_message=message), agent_name=self._name),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
@@ -470,7 +478,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
|
||||
# Broadcast
|
||||
await self.publish_message(
|
||||
GroupChatAgentResponse(agent_response=Response(chat_message=message)),
|
||||
GroupChatAgentResponse(agent_response=Response(chat_message=message), agent_name=self._name),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, List, Mapping
|
||||
from typing import Any, Callable, List, Mapping, Sequence
|
||||
|
||||
from autogen_core import AgentRuntime, Component, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
@@ -69,8 +69,13 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
|
||||
self._current_turn = round_robin_state.current_turn
|
||||
self._next_speaker_index = round_robin_state.next_speaker_index
|
||||
|
||||
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
|
||||
"""Select a speaker from the participants in a round-robin fashion."""
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Select a speaker from the participants in a round-robin fashion.
|
||||
|
||||
.. note::
|
||||
|
||||
This method always returns a single speaker.
|
||||
"""
|
||||
current_speaker_index = self._next_speaker_index
|
||||
self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_names)
|
||||
current_speaker = self._participant_names[current_speaker_index]
|
||||
|
||||
@@ -150,10 +150,14 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)]
|
||||
await self._add_messages_to_context(self._model_context, base_chat_messages)
|
||||
|
||||
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Selects the next speaker in a group chat using a ChatCompletion client,
|
||||
with the selector function as override if it returns a speaker name.
|
||||
|
||||
.. note::
|
||||
|
||||
This method always returns a single speaker name.
|
||||
|
||||
A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
|
||||
"""
|
||||
# Use the selector function if provided.
|
||||
@@ -171,7 +175,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
f"Expected one of: {self._participant_names}."
|
||||
)
|
||||
# Skip the model based selection.
|
||||
return speaker
|
||||
return [speaker]
|
||||
|
||||
# Use the candidate function to filter participants if provided
|
||||
if self._candidate_func is not None:
|
||||
@@ -211,7 +215,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
agent_name = participants[0]
|
||||
self._previous_speaker = agent_name
|
||||
trace_logger.debug(f"Selected speaker: {agent_name}")
|
||||
return agent_name
|
||||
return [agent_name]
|
||||
|
||||
def construct_message_history(self, message_history: List[LLMMessage]) -> str:
|
||||
# Construct the history of the conversation.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, List, Mapping
|
||||
from typing import Any, Callable, List, Mapping, Sequence
|
||||
|
||||
from autogen_core import AgentRuntime, Component, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
@@ -79,17 +79,22 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
||||
await self._termination_condition.reset()
|
||||
self._current_speaker = self._participant_names[0]
|
||||
|
||||
async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
|
||||
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
|
||||
"""Select a speaker from the participants based on handoff message.
|
||||
Looks for the last handoff message in the thread to determine the next speaker."""
|
||||
Looks for the last handoff message in the thread to determine the next speaker.
|
||||
|
||||
.. note::
|
||||
|
||||
This method always returns a single speaker.
|
||||
"""
|
||||
if len(thread) == 0:
|
||||
return self._current_speaker
|
||||
return [self._current_speaker]
|
||||
for message in reversed(thread):
|
||||
if isinstance(message, HandoffMessage):
|
||||
self._current_speaker = message.target
|
||||
# The latest handoff message should always target a valid participant.
|
||||
assert self._current_speaker in self._participant_names
|
||||
return self._current_speaker
|
||||
return [self._current_speaker]
|
||||
return self._current_speaker
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, List, Sequence, Set
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from typing import AsyncGenerator, List, Sequence
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
@@ -12,9 +12,8 @@ from autogen_agentchat.agents import (
|
||||
PerSourceFilter,
|
||||
)
|
||||
from autogen_agentchat.base import Response, TaskResult
|
||||
from autogen_agentchat.conditions import MaxMessageTermination
|
||||
from autogen_agentchat.conditions import MaxMessageTermination, SourceMatchTermination
|
||||
from autogen_agentchat.messages import BaseChatMessage, ChatMessage, MessageFactory, StopMessage, TextMessage
|
||||
from autogen_agentchat.messages import BaseTextChatMessage as TextChatMessage
|
||||
from autogen_agentchat.teams import (
|
||||
DiGraphBuilder,
|
||||
GraphFlow,
|
||||
@@ -269,61 +268,6 @@ def test_validate_graph_mixed_conditions() -> None:
|
||||
graph.graph_validate()
|
||||
|
||||
|
||||
def test_get_valid_target() -> None:
|
||||
node = DiGraphNode(
|
||||
name="A",
|
||||
edges=[DiGraphEdge(target="B", condition="approve"), DiGraphEdge(target="C", condition="reject")],
|
||||
)
|
||||
manager = GraphFlowManager.__new__(GraphFlowManager)
|
||||
|
||||
assert manager._get_valid_target(node, "please approve this") == "B" # pyright: ignore[reportPrivateUsage]
|
||||
assert manager._get_valid_target(node, "i reject this") == "C" # pyright: ignore[reportPrivateUsage]
|
||||
with pytest.raises(RuntimeError):
|
||||
manager._get_valid_target(node, "unknown path") # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
def test_is_node_ready_all_and_any() -> None:
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="C")]),
|
||||
"B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
|
||||
"C": DiGraphNode(name="C", edges=[], activation="all"),
|
||||
}
|
||||
)
|
||||
|
||||
manager = GraphFlowManager.__new__(GraphFlowManager)
|
||||
manager._graph = graph # pyright: ignore[reportPrivateUsage]
|
||||
manager._parents = graph.get_parents() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# === Test "all" activation ===
|
||||
# Case 1: No parent finished
|
||||
manager._pending_execution = {"C": []} # pyright: ignore[reportPrivateUsage]
|
||||
assert not manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Case 2: One parent finished
|
||||
manager._pending_execution = {"C": ["A"]} # pyright: ignore[reportPrivateUsage]
|
||||
assert not manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Case 3: All parents finished
|
||||
manager._pending_execution = {"C": ["A", "B"]} # pyright: ignore[reportPrivateUsage]
|
||||
assert manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# === Test "any" activation ===
|
||||
graph.nodes["C"].activation = "any"
|
||||
|
||||
# Case 1: No parent finished
|
||||
manager._pending_execution = {"C": []} # pyright: ignore[reportPrivateUsage]
|
||||
assert not manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Case 2: One parent finished
|
||||
manager._pending_execution = {"C": ["B"]} # pyright: ignore[reportPrivateUsage]
|
||||
assert manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Case 3: All parents finished
|
||||
manager._pending_execution = {"C": ["A", "B"]} # pyright: ignore[reportPrivateUsage]
|
||||
assert manager._is_node_ready("C") # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_digraph_manager_cycle_without_termination() -> None:
|
||||
"""Test GraphManager raises error for cyclic graph without termination condition."""
|
||||
@@ -359,170 +303,6 @@ async def test_invalid_digraph_manager_cycle_without_termination() -> None:
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def digraph_manager() -> Callable[..., GraphFlowManager]:
|
||||
@patch(
|
||||
"autogen_agentchat.teams._group_chat._base_group_chat_manager.BaseGroupChatManager.__init__", return_value=None
|
||||
)
|
||||
def _create(
|
||||
_: Any,
|
||||
graph: DiGraph,
|
||||
active_nodes: Set[str] | None = None,
|
||||
thread: List[BaseAgentEvent | BaseChatMessage] | None = None,
|
||||
pending: Dict[str, List[str]] | None = None,
|
||||
) -> GraphFlowManager:
|
||||
manager = GraphFlowManager.__new__(GraphFlowManager)
|
||||
manager._graph = graph # pyright: ignore[reportPrivateUsage]
|
||||
manager._parents = graph.get_parents() # pyright: ignore[reportPrivateUsage]
|
||||
manager._start_nodes = graph.get_start_nodes() # pyright: ignore[reportPrivateUsage]
|
||||
manager._leaf_nodes = graph.get_leaf_nodes() # pyright: ignore[reportPrivateUsage]
|
||||
manager._active_nodes = set(active_nodes or []) # pyright: ignore[reportPrivateUsage]
|
||||
manager._active_node_count = {node: 0 for node in graph.nodes} # pyright: ignore[reportPrivateUsage]
|
||||
manager._message_factory = MessageFactory() # pyright: ignore[reportPrivateUsage]
|
||||
manager._message_thread = thread if thread is not None else [] # pyright: ignore[reportPrivateUsage]
|
||||
manager._pending_execution = pending if pending is not None else {node: [] for node in graph.get_start_nodes()} # pyright: ignore[reportPrivateUsage]
|
||||
manager._name = "test_manager" # pyright: ignore[reportPrivateUsage]
|
||||
manager._use_default_start = False # pyright: ignore[reportPrivateUsage]
|
||||
return manager
|
||||
|
||||
return _create
|
||||
|
||||
|
||||
# -------------------- Test: Sequential Flow --------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_select_speakers_linear(digraph_manager: Callable[..., GraphFlowManager]) -> None:
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
|
||||
"B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
|
||||
"C": DiGraphNode(name="C", edges=[]),
|
||||
}
|
||||
)
|
||||
message_thread = [TextChatMessage(source="A", content="done", metadata={})]
|
||||
manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []})
|
||||
|
||||
result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage]
|
||||
assert result == ["B"]
|
||||
assert "B" in manager._active_nodes # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
# -------------------- Test: Parallel Fan-out --------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_select_speakers_parallel(digraph_manager: Callable[..., GraphFlowManager]) -> None:
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
|
||||
"B": DiGraphNode(name="B", edges=[]),
|
||||
"C": DiGraphNode(name="C", edges=[]),
|
||||
}
|
||||
)
|
||||
message_thread = [TextChatMessage(source="A", content="done", metadata={})]
|
||||
manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []})
|
||||
|
||||
result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage]
|
||||
assert set(result) == {"B", "C"}
|
||||
assert "B" in manager._active_nodes # pyright: ignore[reportPrivateUsage]
|
||||
assert "C" in manager._active_nodes # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
# -------------------- Test: Conditional Path --------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_select_speakers_conditional(digraph_manager: Callable[..., GraphFlowManager]) -> None:
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(
|
||||
name="A", edges=[DiGraphEdge(target="B", condition="yes"), DiGraphEdge(target="C", condition="no")]
|
||||
),
|
||||
"B": DiGraphNode(name="B", edges=[]),
|
||||
"C": DiGraphNode(name="C", edges=[]),
|
||||
}
|
||||
)
|
||||
message_thread = [TextChatMessage(source="A", content="no", metadata={})]
|
||||
manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []})
|
||||
|
||||
result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage]
|
||||
assert result == ["C"]
|
||||
assert "C" in manager._active_nodes # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_select_speakers_from_start_nodes(digraph_manager: Callable[..., GraphFlowManager]) -> None:
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(name="A", edges=[]),
|
||||
"B": DiGraphNode(name="B", edges=[]),
|
||||
}
|
||||
)
|
||||
# No prior message — both are start nodes
|
||||
manager = digraph_manager(graph=graph, active_nodes=set(), thread=[], pending={"A": [], "B": []})
|
||||
result = await manager.select_speakers([])
|
||||
assert set(result) == {"A", "B"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_select_speakers_termination(digraph_manager: Callable[..., GraphFlowManager]) -> None:
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(name="A", edges=[]),
|
||||
}
|
||||
)
|
||||
|
||||
# Create the manager and manually patch _signal_termination to track calls
|
||||
manager = digraph_manager(
|
||||
graph=graph, active_nodes={"A"}, thread=[TextChatMessage(source="A", content="done", metadata={})], pending={}
|
||||
)
|
||||
manager._signal_termination = AsyncMock() # type: ignore[assignment]
|
||||
|
||||
result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# No speakers left to run, so result should be empty
|
||||
assert result == [_DIGRAPH_STOP_AGENT_NAME]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_select_speakers_conditional_all_activation(
|
||||
digraph_manager: Callable[..., GraphFlowManager],
|
||||
) -> None:
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(
|
||||
name="A", edges=[DiGraphEdge(target="B", condition="yes"), DiGraphEdge(target="C", condition="no")]
|
||||
),
|
||||
"B": DiGraphNode(name="B", edges=[], activation="all"),
|
||||
"C": DiGraphNode(name="C", edges=[], activation="all"),
|
||||
}
|
||||
)
|
||||
message_thread = [TextChatMessage(source="A", content="no", metadata={})]
|
||||
manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []})
|
||||
|
||||
result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage]
|
||||
assert result == ["C"]
|
||||
assert "C" in manager._active_nodes # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_select_speakers_conditional_any_activation(
|
||||
digraph_manager: Callable[..., GraphFlowManager],
|
||||
) -> None:
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(
|
||||
name="A", edges=[DiGraphEdge(target="B", condition="yes"), DiGraphEdge(target="C", condition="no")]
|
||||
),
|
||||
"B": DiGraphNode(name="B", edges=[], activation="any"),
|
||||
"C": DiGraphNode(name="C", edges=[], activation="any"),
|
||||
}
|
||||
)
|
||||
message_thread = [TextChatMessage(source="A", content="yes", metadata={})]
|
||||
manager = digraph_manager(graph=graph, active_nodes={"A"}, thread=message_thread, pending={"B": [], "C": []})
|
||||
|
||||
result = await manager.select_speakers(manager._message_thread) # pyright: ignore[reportPrivateUsage]
|
||||
assert result == ["B"]
|
||||
assert "B" in manager._active_nodes # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
class _EchoAgent(BaseChatAgent):
|
||||
def __init__(self, name: str, description: str) -> None:
|
||||
super().__init__(name, description)
|
||||
@@ -1417,7 +1197,6 @@ async def test_graph_flow_serialize_deserialize() -> None:
|
||||
participants=builder.get_participants(),
|
||||
graph=builder.build(),
|
||||
runtime=None,
|
||||
termination_condition=MaxMessageTermination(5),
|
||||
)
|
||||
|
||||
serialized = team.dump_component()
|
||||
@@ -1439,11 +1218,52 @@ async def test_graph_flow_serialize_deserialize() -> None:
|
||||
assert results.messages[1].source == "A"
|
||||
assert results.messages[1].content == "0"
|
||||
assert isinstance(results.messages[2], TextMessage)
|
||||
assert results.messages[2].source == "A"
|
||||
assert results.messages[2].content == "1"
|
||||
assert isinstance(results.messages[3], TextMessage)
|
||||
assert results.messages[3].source == "B"
|
||||
assert results.messages[3].content == "0"
|
||||
assert results.messages[2].source == "B"
|
||||
assert results.messages[2].content == "0"
|
||||
assert isinstance(results.messages[-1], StopMessage)
|
||||
assert results.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
|
||||
assert results.messages[-1].content == "Digraph execution is complete"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_flow_stateful_pause_and_resume_with_termination() -> None:
|
||||
client_a = ReplayChatCompletionClient(["A1", "A2"])
|
||||
client_b = ReplayChatCompletionClient(["B1"])
|
||||
|
||||
a = AssistantAgent("A", model_client=client_a)
|
||||
b = AssistantAgent("B", model_client=client_b)
|
||||
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(a).add_node(b)
|
||||
builder.add_edge(a, b)
|
||||
builder.set_entry_point(a)
|
||||
|
||||
team = GraphFlow(
|
||||
participants=builder.get_participants(),
|
||||
graph=builder.build(),
|
||||
runtime=None,
|
||||
termination_condition=SourceMatchTermination(sources=["A"]),
|
||||
)
|
||||
|
||||
result = await team.run(task="Start")
|
||||
assert len(result.messages) == 2
|
||||
assert result.messages[0].source == "user"
|
||||
assert result.messages[1].source == "A"
|
||||
assert result.stop_reason is not None and result.stop_reason == "'A' answered"
|
||||
|
||||
# Export state.
|
||||
state = await team.save_state()
|
||||
|
||||
# Load state into a new team.
|
||||
new_team = GraphFlow(
|
||||
participants=builder.get_participants(),
|
||||
graph=builder.build(),
|
||||
runtime=None,
|
||||
)
|
||||
await new_team.load_state(state)
|
||||
|
||||
# Resume.
|
||||
result = await new_team.run()
|
||||
assert len(result.messages) == 2
|
||||
assert result.messages[0].source == "B"
|
||||
assert result.messages[1].source == _DIGRAPH_STOP_AGENT_NAME
|
||||
|
||||
Reference in New Issue
Block a user