Support for external agent runtime in AgentChat (#5843)

Resolves #4075

1. Introduce custom runtime parameter for all AgentChat teams
(RoundRobinGroupChat, SelectorGroupChat, etc.). This is done by making
sure each team's topics are isolated from other teams, and decoupling
state from agent identities. Also, I removed the closure agent from the
BaseGroupChat and use the group chat manager agent to relay messages to
the output message queue.
2. Added unit tests covering scenarios with custom runtimes, using a
pytest fixture.
3. Refactored existing unit tests to use ReplayChatCompletionClient with
a few improvements to the client.
4. Fixed a one-line bug in AssistantAgent that caused a deserialized agent
to always have handoffs, even when none were configured.

How to use it? 

```python
import asyncio
from autogen_core import SingleThreadedAgentRuntime
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.conditions import TextMentionTermination
from autogen_ext.models.replay import ReplayChatCompletionClient

async def main() -> None:
    """Demonstrate running AgentChat teams on an externally managed runtime."""
    # The runtime is created and started by the caller, not by the team.
    runtime = SingleThreadedAgentRuntime()
    runtime.start()

    # Replay client that yields the canned responses "1" through "10".
    replay_client = ReplayChatCompletionClient([str(i) for i in range(1, 11)])

    def make_agent(name: str) -> AssistantAgent:
        # All agents share the same replay client and system message.
        return AssistantAgent(name, model_client=replay_client, system_message="You are a helpful assistant.")

    # Stop once either agent says "10".
    stop_on_ten = TextMentionTermination("10", sources=["assistant1", "assistant2"])

    # Two agents that take turns counting, attached to the external runtime.
    team = RoundRobinGroupChat(
        [make_agent("assistant1"), make_agent("assistant2")],
        runtime=runtime,
        termination_condition=stop_on_ten,
    )

    async def run_and_print(t: RoundRobinGroupChat) -> None:
        # Stream the team run and echo every message as it arrives.
        async for event in t.run_stream(task="Count to 10."):
            print(event)

    # First run.
    await run_and_print(team)

    # Snapshot the team state, then load it back into the same team.
    snapshot = await team.save_state()
    await team.load_state(snapshot)

    # Second run after restoring state.
    replay_client.reset()
    await run_and_print(team)

    # A brand-new team with the same agent names can also consume the snapshot.
    fresh_team = RoundRobinGroupChat(
        [make_agent("assistant1"), make_agent("assistant2")],
        runtime=runtime,
        termination_condition=stop_on_ten,
    )
    await fresh_team.load_state(snapshot)

    # Run the restored team.
    replay_client.reset()
    await run_and_print(fresh_team)

    # The caller owns the runtime's lifecycle.
    await runtime.stop()

asyncio.run(main())
```

TODOs as future PRs:
1. Documentation.
2. Decide how to handle errors in a custom runtime when an agent raises an exception.

---------

Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
This commit is contained in:
Eric Zhu
2025-03-06 10:32:52 -08:00
committed by GitHub
parent 30b1b8f90c
commit 7e5c1154cf
14 changed files with 775 additions and 1138 deletions

View File

@@ -1192,7 +1192,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
name=self.name,
model_client=self._model_client.dump_component(),
tools=[tool.dump_component() for tool in self._tools],
handoffs=list(self._handoffs.values()),
handoffs=list(self._handoffs.values()) if self._handoffs else None,
model_context=self._model_context.dump_component(),
memory=[memory.dump_component() for memory in self._memory] if self._memory else None,
description=self.description,

View File

@@ -29,7 +29,6 @@ class TeamState(BaseState):
"""State for a team of agents."""
agent_states: Mapping[str, Any] = Field(default_factory=dict)
team_id: str = Field(default="")
type: str = Field(default="TeamState")

View File

@@ -2,29 +2,32 @@ import asyncio
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, List, Mapping, Sequence
from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Sequence
from autogen_core import (
AgentId,
AgentInstantiationContext,
AgentRuntime,
AgentType,
CancellationToken,
ClosureAgent,
ComponentBase,
MessageContext,
SingleThreadedAgentRuntime,
TypeSubscription,
)
from autogen_core._closure_agent import ClosureContext
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import AgentEvent, BaseChatMessage, ChatMessage, ModelClientStreamingChunkEvent, TextMessage
from ...messages import (
AgentEvent,
BaseChatMessage,
ChatMessage,
ModelClientStreamingChunkEvent,
StopMessage,
TextMessage,
)
from ...state import TeamState
from ._chat_agent_container import ChatAgentContainer
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
from ._events import GroupChatReset, GroupChatStart, GroupChatTermination
from ._sequential_routed_agent import SequentialRoutedAgent
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
@@ -42,9 +45,11 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
def __init__(
self,
participants: List[ChatAgent],
group_chat_manager_name: str,
group_chat_manager_class: type[SequentialRoutedAgent],
termination_condition: TerminationCondition | None = None,
max_turns: int | None = None,
runtime: AgentRuntime | None = None,
):
if len(participants) == 0:
raise ValueError("At least one participant is required.")
@@ -55,23 +60,42 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
self._termination_condition = termination_condition
self._max_turns = max_turns
# Constants for the group chat.
# The team ID is a UUID that is used to identify the team and its participants
# in the agent runtime. It is used to create unique topic types for each participant.
# Currently, team ID is bound to an object instance of the group chat class.
# So if you create two instances of group chat, there will be two teams with different IDs.
self._team_id = str(uuid.uuid4())
self._group_topic_type = "group_topic"
self._output_topic_type = "output_topic"
self._group_chat_manager_topic_type = "group_chat_manager"
self._participant_topic_types: List[str] = [participant.name for participant in participants]
self._participant_descriptions: List[str] = [participant.description for participant in participants]
self._collector_agent_type = "collect_output_messages"
# Constants for the closure agent to collect the output messages.
self._stop_reason: str | None = None
self._output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | None] = asyncio.Queue()
# Constants for the group chat team.
# The names are used to identify the agents within the team.
# The names may not be unique across different teams.
self._group_chat_manager_name = group_chat_manager_name
self._participant_names: List[str] = [participant.name for participant in participants]
self._participant_descriptions: List[str] = [participant.description for participant in participants]
# The group chat topic type is used for broadcast communication among all participants and the group chat manager.
self._group_topic_type = f"group_topic_{self._team_id}"
# The group chat manager topic type is used for direct communication with the group chat manager.
self._group_chat_manager_topic_type = f"{self._group_chat_manager_name}_{self._team_id}"
# The participant topic types are used for direct communication with each participant.
self._participant_topic_types: List[str] = [
f"{participant.name}_{self._team_id}" for participant in participants
]
# The output topic type is used for emitting streaming messages from the group chat.
# The group chat manager will relay the messages to the output message queue.
self._output_topic_type = f"output_topic_{self._team_id}"
# The queue for collecting the output messages.
self._output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination] = asyncio.Queue()
# Create a runtime for the team.
# TODO: The runtime should be created by a managed context.
# Background exceptions must not be ignored as it results in non-surfaced exceptions and early team termination.
self._runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
if runtime is not None:
self._runtime = runtime
self._embedded_runtime = False
else:
# Use a embedded single-threaded runtime for the group chat.
# Background exceptions must not be ignored as it results in non-surfaced exceptions and early team termination.
self._runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
self._embedded_runtime = True
# Flag to track if the group chat has been initialized.
self._initialized = False
@@ -82,10 +106,13 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
@abstractmethod
def _create_group_chat_manager_factory(
self,
name: str,
group_topic_type: str,
output_topic_type: str,
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
) -> Callable[[], SequentialRoutedAgent]: ...
@@ -97,10 +124,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
agent: ChatAgent,
) -> Callable[[], ChatAgentContainer]:
def _factory() -> ChatAgentContainer:
id = AgentInstantiationContext.current_agent_id()
assert id == AgentId(type=agent.name, key=self._team_id)
container = ChatAgentContainer(parent_topic_type, output_topic_type, agent)
assert container.id == id
return container
return _factory
@@ -110,9 +134,8 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
group_chat_manager_agent_type = AgentType(self._group_chat_manager_topic_type)
# Register participants.
for participant, participant_topic_type in zip(self._participants, self._participant_topic_types, strict=False):
# Use the participant topic type as the agent type.
agent_type = participant_topic_type
# Use the participant topic type as the agent type.
for participant, agent_type in zip(self._participants, self._participant_topic_types, strict=True):
# Register the participant factory.
await ChatAgentContainer.register(
runtime,
@@ -120,7 +143,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
factory=self._create_participant_factory(self._group_topic_type, self._output_topic_type, participant),
)
# Add subscriptions for the participant.
await runtime.add_subscription(TypeSubscription(topic_type=participant_topic_type, agent_type=agent_type))
# The participant should be able to receive messages from its own topic.
await runtime.add_subscription(TypeSubscription(topic_type=agent_type, agent_type=agent_type))
# The participant should be able to receive messages from the group topic.
await runtime.add_subscription(TypeSubscription(topic_type=self._group_topic_type, agent_type=agent_type))
# Register the group chat manager.
@@ -128,50 +153,33 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
runtime,
type=group_chat_manager_agent_type.type,
factory=self._create_group_chat_manager_factory(
name=self._group_chat_manager_name,
group_topic_type=self._group_topic_type,
output_topic_type=self._output_topic_type,
participant_names=self._participant_names,
participant_topic_types=self._participant_topic_types,
participant_descriptions=self._participant_descriptions,
output_message_queue=self._output_message_queue,
termination_condition=self._termination_condition,
max_turns=self._max_turns,
),
)
# Add subscriptions for the group chat manager.
# The group chat manager should be able to receive messages from its own topic.
await runtime.add_subscription(
TypeSubscription(
topic_type=self._group_chat_manager_topic_type, agent_type=group_chat_manager_agent_type.type
)
)
# The group chat manager should be able to receive messages from the group topic.
await runtime.add_subscription(
TypeSubscription(topic_type=self._group_topic_type, agent_type=group_chat_manager_agent_type.type)
)
async def collect_output_messages(
_runtime: ClosureContext,
message: GroupChatStart | GroupChatMessage | GroupChatTermination,
ctx: MessageContext,
) -> None:
"""Collect output messages from the group chat."""
if isinstance(message, GroupChatStart):
if message.messages is not None:
for msg in message.messages:
event_logger.info(msg)
await self._output_message_queue.put(msg)
elif isinstance(message, GroupChatMessage):
event_logger.info(message.message)
await self._output_message_queue.put(message.message)
elif isinstance(message, GroupChatTermination):
event_logger.info(message.message)
self._stop_reason = message.message.content
await ClosureAgent.register_closure(
runtime,
type=self._collector_agent_type,
closure=collect_output_messages,
subscriptions=lambda: [
TypeSubscription(topic_type=self._output_topic_type, agent_type=self._collector_agent_type),
],
# The group chat manager will relay the messages from output topic to the output message queue.
await runtime.add_subscription(
TypeSubscription(topic_type=self._output_topic_type, agent_type=group_chat_manager_agent_type.type)
)
self._initialized = True
async def run(
@@ -400,26 +408,40 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
raise ValueError("The team is already running, it cannot run again until it is stopped.")
self._is_running = True
# Start the runtime.
# TODO: The runtime should be started by a managed context.
self._runtime.start()
if self._embedded_runtime:
# Start the embedded runtime.
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
self._runtime.start()
if not self._initialized:
await self._init(self._runtime)
# Start a coroutine to stop the runtime and signal the output message queue is complete.
async def stop_runtime() -> None:
try:
await self._runtime.stop_when_idle()
finally:
await self._output_message_queue.put(None)
shutdown_task: asyncio.Task[None] | None = None
if self._embedded_runtime:
shutdown_task = asyncio.create_task(stop_runtime())
async def stop_runtime() -> None:
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
try:
# This will propagate any exceptions raised.
await self._runtime.stop_when_idle()
finally:
# Stop the consumption of messages and end the stream.
# NOTE: we also need to put a GroupChatTermination event here because when the group chat
# has an exception, the group chat manager may not be able to put a GroupChatTermination event in the queue.
await self._output_message_queue.put(
GroupChatTermination(
message=StopMessage(content="Exception occurred.", source=self._group_chat_manager_name)
)
)
# Create a background task to stop the runtime when the group chat
# is stopped or has an exception.
shutdown_task = asyncio.create_task(stop_runtime())
try:
# Run the team by sending the start message to the group chat manager.
# The group chat manager will start the group chat by relaying the message to the participants
# and the closure agent.
# and the group chat manager.
await self._runtime.send_message(
GroupChatStart(messages=messages),
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
@@ -427,6 +449,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
)
# Collect the output messages in order.
output_messages: List[AgentEvent | ChatMessage] = []
stop_reason: str | None = None
# Yield the messages until the queue is empty.
while True:
message_future = asyncio.ensure_future(self._output_message_queue.get())
@@ -434,7 +457,13 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
cancellation_token.link_future(message_future)
# Wait for the next message, this will raise an exception if the task is cancelled.
message = await message_future
if message is None:
if isinstance(message, GroupChatTermination):
# If the message is a GroupChatTermination, the group chat has terminated.
# TODO: how do we handle termination when the runtime is not embedded
# and there is an exception in the group chat?
# The group chat manager may not be able to put a GroupChatTermination event in the queue,
# and this loop will never end.
stop_reason = message.message.content
break
yield message
if isinstance(message, ModelClientStreamingChunkEvent):
@@ -443,14 +472,14 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
output_messages.append(message)
# Yield the final result.
yield TaskResult(messages=output_messages, stop_reason=self._stop_reason)
yield TaskResult(messages=output_messages, stop_reason=stop_reason)
finally:
# Wait for the shutdown task to finish.
try:
# This will propagate any exceptions raised in the shutdown task.
# We need to ensure we cleanup though.
await shutdown_task
if shutdown_task is not None:
# Wait for the shutdown task to finish.
# This will propagate any exceptions raised.
await shutdown_task
finally:
# Clear the output message queue.
while not self._output_message_queue.empty():
@@ -506,8 +535,10 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
raise RuntimeError("The group chat is currently running. It must be stopped before it can be reset.")
self._is_running = True
# Start the runtime.
self._runtime.start()
if self._embedded_runtime:
# Start the runtime.
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
self._runtime.start()
try:
# Send a reset messages to all participants.
@@ -522,11 +553,12 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
)
finally:
# Stop the runtime.
await self._runtime.stop_when_idle()
if self._embedded_runtime:
# Stop the runtime.
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
await self._runtime.stop_when_idle()
# Reset the output message queue.
self._stop_reason = None
while not self._output_message_queue.empty():
self._output_message_queue.get_nowait()
@@ -534,7 +566,32 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
self._is_running = False
async def save_state(self) -> Mapping[str, Any]:
"""Save the state of the group chat team."""
"""Save the state of the group chat team.
The state is saved by calling the :meth:`~autogen_core.AgentRuntime.agent_save_state` method
on each participant and the group chat manager with their internal agent ID.
The state is returned as a nested dictionary: a dictionary with key `agent_states`,
which is a dictionary the agent names as keys and the state as values.
.. code-block:: text
{
"agent_states": {
"agent1": ...,
"agent2": ...,
"RoundRobinGroupChatManager": ...
}
}
.. note::
Starting v0.4.9, the state is using the agent name as the key instead of the agent ID,
and the `team_id` field is removed from the state. This is to allow the state to be
portable across different teams and runtimes. States saved with the old format
may not be compatible with the new format in the future.
"""
if not self._initialized:
raise RuntimeError("The group chat has not been initialized. It must be run before it can be saved.")
@@ -543,15 +600,31 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
self._is_running = True
try:
# Save the state of the runtime. This will save the state of the participants and the group chat manager.
agent_states = await self._runtime.save_state()
return TeamState(agent_states=agent_states, team_id=self._team_id).model_dump()
# Store state of each agent by their name.
# NOTE: we don't use the agent ID as the key here because we need to be able to decouple
# the state of the agents from their identities in the agent runtime.
agent_states: Dict[str, Mapping[str, Any]] = {}
# Save the state of all participants.
for name, agent_type in zip(self._participant_names, self._participant_topic_types, strict=True):
agent_id = AgentId(type=agent_type, key=self._team_id)
# NOTE: We are using the runtime's save state method rather than the agent instance's
# save_state method because we want to support saving state of remote agents.
agent_states[name] = await self._runtime.agent_save_state(agent_id)
# Save the state of the group chat manager.
agent_id = AgentId(type=self._group_chat_manager_topic_type, key=self._team_id)
agent_states[self._group_chat_manager_name] = await self._runtime.agent_save_state(agent_id)
return TeamState(agent_states=agent_states).model_dump()
finally:
# Indicate that the team is no longer running.
self._is_running = False
async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load the state of the group chat team."""
"""Load an external state and overwrite the current state of the group chat team.
The state is loaded by calling the :meth:`~autogen_core.AgentRuntime.agent_load_state` method
on each participant and the group chat manager with their internal agent ID.
See :meth:`~autogen_agentchat.teams.BaseGroupChat.save_state` for the expected format of the state.
"""
if not self._initialized:
await self._init(self._runtime)
@@ -560,10 +633,25 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
self._is_running = True
try:
# Load the state of the runtime. This will load the state of the participants and the group chat manager.
team_state = TeamState.model_validate(state)
self._team_id = team_state.team_id
await self._runtime.load_state(team_state.agent_states)
# Load the state of all participants.
for name, agent_type in zip(self._participant_names, self._participant_topic_types, strict=True):
agent_id = AgentId(type=agent_type, key=self._team_id)
if name not in team_state.agent_states:
raise ValueError(f"Agent state for {name} not found in the saved state.")
await self._runtime.agent_load_state(agent_id, team_state.agent_states[name])
# Load the state of the group chat manager.
agent_id = AgentId(type=self._group_chat_manager_topic_type, key=self._team_id)
if self._group_chat_manager_name not in team_state.agent_states:
raise ValueError(f"Agent state for {self._group_chat_manager_name} not found in the saved state.")
await self._runtime.agent_load_state(agent_id, team_state.agent_states[self._group_chat_manager_name])
except ValidationError as e:
raise ValueError(
"Invalid state format. The expected state format has changed since v0.4.9. "
"Please read the release note on GitHub."
) from e
finally:
# Indicate that the team is no longer running.
self._is_running = False

View File

@@ -8,6 +8,7 @@ from ...base import TerminationCondition
from ...messages import AgentEvent, ChatMessage, StopMessage
from ._events import (
GroupChatAgentResponse,
GroupChatMessage,
GroupChatRequestPublish,
GroupChatReset,
GroupChatStart,
@@ -30,14 +31,18 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
def __init__(
self,
name: str,
group_topic_type: str,
output_topic_type: str,
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None = None,
max_turns: int | None = None,
):
super().__init__(description="Group chat manager")
self._name = name
self._group_topic_type = group_topic_type
self._output_topic_type = output_topic_type
if len(participant_topic_types) != len(participant_descriptions):
@@ -46,9 +51,13 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
raise ValueError("The participant topic types must be unique.")
if group_topic_type in participant_topic_types:
raise ValueError("The group topic type must not be in the participant topic types.")
self._participant_topic_types = participant_topic_types
self._participant_names = participant_names
self._participant_name_to_topic_type = {
name: topic_type for name, topic_type in zip(participant_names, participant_topic_types, strict=True)
}
self._participant_descriptions = participant_descriptions
self._message_thread: List[AgentEvent | ChatMessage] = []
self._output_message_queue = output_message_queue
self._termination_condition = termination_condition
if max_turns is not None and max_turns <= 0:
raise ValueError("The maximum number of turns must be greater than 0.")
@@ -62,11 +71,11 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
# Check if the conversation has already terminated.
if self._termination_condition is not None and self._termination_condition.terminated:
early_stop_message = StopMessage(
content="The group chat has already terminated.", source="Group chat manager"
)
await self.publish_message(
GroupChatTermination(message=early_stop_message), topic_id=DefaultTopicId(type=self._output_topic_type)
content="The group chat has already terminated.",
source=self._name,
)
# Signal termination to the caller of the team.
await self._signal_termination(early_stop_message)
# Stop the group chat.
return
@@ -76,8 +85,11 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
if message.messages is not None:
# Log all messages at once
await self.publish_message(
GroupChatStart(messages=message.messages), topic_id=DefaultTopicId(type=self._output_topic_type)
GroupChatStart(messages=message.messages),
topic_id=DefaultTopicId(type=self._output_topic_type),
)
for msg in message.messages:
await self._output_message_queue.put(msg)
# Relay all messages at once to participants
await self.publish_message(
@@ -93,19 +105,21 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
if self._termination_condition is not None:
stop_message = await self._termination_condition(message.messages)
if stop_message is not None:
await self.publish_message(
GroupChatTermination(message=stop_message),
topic_id=DefaultTopicId(type=self._output_topic_type),
)
# Stop the group chat and reset the termination condition.
# Reset the termination condition.
await self._termination_condition.reset()
# Signal termination to the caller of the team.
await self._signal_termination(stop_message)
# Stop the group chat.
return
# Select a speaker to start/continue the conversation
speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_topic_type_future)
speaker_topic_type = await speaker_topic_type_future
ctx.cancellation_token.link_future(speaker_name_future)
speaker_name = await speaker_name_future
if speaker_name not in self._participant_name_to_topic_type:
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=speaker_topic_type),
@@ -127,12 +141,12 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
if self._termination_condition is not None:
stop_message = await self._termination_condition(delta)
if stop_message is not None:
await self.publish_message(
GroupChatTermination(message=stop_message), topic_id=DefaultTopicId(type=self._output_topic_type)
)
# Stop the group chat and reset the termination conditions and turn count.
# Reset the termination conditions and turn count.
await self._termination_condition.reset()
self._current_turn = 0
# Signal termination to the caller of the team.
await self._signal_termination(stop_message)
# Stop the group chat.
return
# Increment the turn count.
@@ -142,28 +156,45 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
if self._current_turn >= self._max_turns:
stop_message = StopMessage(
content=f"Maximum number of turns {self._max_turns} reached.",
source="Group chat manager",
source=self._name,
)
await self.publish_message(
GroupChatTermination(message=stop_message), topic_id=DefaultTopicId(type=self._output_topic_type)
)
# Stop the group chat and reset the termination conditions and turn count.
# Reset the termination conditions and turn count.
if self._termination_condition is not None:
await self._termination_condition.reset()
self._current_turn = 0
# Signal termination to the caller of the team.
await self._signal_termination(stop_message)
# Stop the group chat.
return
# Select a speaker to continue the conversation.
speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_topic_type_future)
speaker_topic_type = await speaker_topic_type_future
ctx.cancellation_token.link_future(speaker_name_future)
speaker_name = await speaker_name_future
if speaker_name not in self._participant_name_to_topic_type:
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=speaker_topic_type),
cancellation_token=ctx.cancellation_token,
)
async def _signal_termination(self, message: StopMessage) -> None:
termination_event = GroupChatTermination(message=message)
# Log the early stop message.
await self.publish_message(
termination_event,
topic_id=DefaultTopicId(type=self._output_topic_type),
)
# Put the termination event in the output message queue.
await self._output_message_queue.put(termination_event)
@event
async def handle_group_chat_message(self, message: GroupChatMessage, ctx: MessageContext) -> None:
await self._output_message_queue.put(message.message)
@rpc
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
# Reset the group chat manager.

View File

@@ -1,14 +1,17 @@
import asyncio
import logging
from typing import Callable, List
from autogen_core import Component, ComponentModel
from autogen_core import AgentRuntime, Component, ComponentModel
from autogen_core.models import ChatCompletionClient
from pydantic import BaseModel
from typing_extensions import Self
from .... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
from ....base import ChatAgent, TerminationCondition
from ....messages import AgentEvent, ChatMessage
from .._base_group_chat import BaseGroupChat
from .._events import GroupChatTermination
from ._magentic_one_orchestrator import MagenticOneOrchestrator
from ._prompts import ORCHESTRATOR_FINAL_ANSWER_PROMPT
@@ -97,14 +100,17 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
*,
termination_condition: TerminationCondition | None = None,
max_turns: int | None = 20,
runtime: AgentRuntime | None = None,
max_stalls: int = 3,
final_answer_prompt: str = ORCHESTRATOR_FINAL_ANSWER_PROMPT,
):
super().__init__(
participants,
group_chat_manager_name="MagenticOneOrchestrator",
group_chat_manager_class=MagenticOneOrchestrator,
termination_condition=termination_condition,
max_turns=max_turns,
runtime=runtime,
)
# Validate the participants.
@@ -116,22 +122,28 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
def _create_group_chat_manager_factory(
self,
name: str,
group_topic_type: str,
output_topic_type: str,
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
) -> Callable[[], MagenticOneOrchestrator]:
return lambda: MagenticOneOrchestrator(
name,
group_topic_type,
output_topic_type,
participant_topic_types,
participant_names,
participant_descriptions,
max_turns,
self._model_client,
self._max_stalls,
self._final_answer_prompt,
output_message_queue,
termination_condition,
)

View File

@@ -1,3 +1,4 @@
import asyncio
import json
import logging
import re
@@ -53,28 +54,33 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
def __init__(
self,
name: str,
group_topic_type: str,
output_topic_type: str,
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
max_turns: int | None,
model_client: ChatCompletionClient,
max_stalls: int,
final_answer_prompt: str,
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
):
super().__init__(
name,
group_topic_type,
output_topic_type,
participant_topic_types,
participant_names,
participant_descriptions,
output_message_queue,
termination_condition,
max_turns,
)
self._model_client = model_client
self._max_stalls = max_stalls
self._final_answer_prompt = final_answer_prompt
self._name = "MagenticOneOrchestrator"
self._max_json_retries = 10
self._task = ""
self._facts = ""
@@ -84,7 +90,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
# Produce a team description. Each agent should appear on a single line.
self._team_description = ""
for topic_type, description in zip(self._participant_topic_types, self._participant_descriptions, strict=True):
for topic_type, description in zip(self._participant_names, self._participant_descriptions, strict=True):
self._team_description += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n"
self._team_description = self._team_description.strip()
@@ -122,9 +128,8 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
# Check if the conversation has already terminated.
if self._termination_condition is not None and self._termination_condition.terminated:
early_stop_message = StopMessage(content="The group chat has already terminated.", source=self._name)
await self.publish_message(
GroupChatTermination(message=early_stop_message), topic_id=DefaultTopicId(type=self._output_topic_type)
)
# Signal termination.
await self._signal_termination(early_stop_message)
# Stop the group chat.
return
assert message is not None and message.messages is not None
@@ -132,8 +137,12 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
# Validate the group state given all the messages.
await self.validate_group_state(message.messages)
# Log the message.
# Log the message to the output topic.
await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))
# Log the message to the output queue.
for msg in message.messages:
await self._output_message_queue.put(msg)
# Outer Loop for first time
# Create the initial task ledger
#################################
@@ -182,11 +191,10 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
if self._termination_condition is not None:
stop_message = await self._termination_condition(delta)
if stop_message is not None:
await self.publish_message(
GroupChatTermination(message=stop_message), topic_id=DefaultTopicId(type=self._output_topic_type)
)
# Stop the group chat and reset the termination conditions and turn count.
# Reset the termination conditions.
await self._termination_condition.reset()
# Signal termination.
await self._signal_termination(stop_message)
return
await self._orchestrate_step(ctx.cancellation_token)
@@ -233,7 +241,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
async def _reenter_outer_loop(self, cancellation_token: CancellationToken) -> None:
"""Re-enter Outer loop of the orchestrator after creating task ledger."""
# Reset the agents
for participant_topic_type in self._participant_topic_types:
for participant_topic_type in self._participant_name_to_topic_type.values():
await self._runtime.send_message(
GroupChatReset(),
recipient=AgentId(type=participant_topic_type, key=self.id.key),
@@ -251,11 +259,13 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
# Save my copy
self._message_thread.append(ledger_message)
# Log it
# Log it to the output topic.
await self.publish_message(
GroupChatMessage(message=ledger_message),
topic_id=DefaultTopicId(type=self._output_topic_type),
)
# Log it to the output queue.
await self._output_message_queue.put(ledger_message)
# Broadcast
await self.publish_message(
@@ -278,7 +288,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
context = self._thread_to_context()
progress_ledger_prompt = self._get_progress_ledger_prompt(
self._task, self._team_description, self._participant_topic_types
self._task, self._team_description, self._participant_names
)
context.append(UserMessage(content=progress_ledger_prompt, source=self._name))
progress_ledger: Dict[str, Any] = {}
@@ -292,10 +302,10 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
progress_ledger = json.loads(ledger_str)
# If the team consists of a single agent, deterministically set the next speaker
if len(self._participant_topic_types) == 1:
if len(self._participant_names) == 1:
progress_ledger["next_speaker"] = {
"reason": "The team consists of only one agent.",
"answer": self._participant_topic_types[0],
"answer": self._participant_names[0],
}
# Validate the structure
@@ -321,7 +331,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
# Validate the next speaker if the task is not yet complete
if (
not progress_ledger["is_request_satisfied"]["answer"]
and progress_ledger["next_speaker"]["answer"] not in self._participant_topic_types
and progress_ledger["next_speaker"]["answer"] not in self._participant_names
):
key_error = True
break
@@ -362,12 +372,14 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
message = TextMessage(content=progress_ledger["instruction_or_question"]["answer"], source=self._name)
self._message_thread.append(message) # My copy
# Log it
await self._log_message(f"Next Speaker: {progress_ledger['next_speaker']['answer']}")
# Log it to the output topic.
await self.publish_message(
GroupChatMessage(message=message),
topic_id=DefaultTopicId(type=self._output_topic_type),
)
# Log it to the output queue.
await self._output_message_queue.put(message)
# Broadcast it
await self.publish_message( # Broadcast
@@ -377,21 +389,18 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
)
# Request that the step be completed
valid_next_speaker: bool = False
next_speaker = progress_ledger["next_speaker"]["answer"]
for participant_topic_type in self._participant_topic_types:
if participant_topic_type == next_speaker:
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=next_speaker),
cancellation_token=cancellation_token,
)
valid_next_speaker = True
break
if not valid_next_speaker:
# Check if the next speaker is valid
if next_speaker not in self._participant_name_to_topic_type:
raise ValueError(
f"Invalid next speaker: {next_speaker} from the ledger, participants are: {self._participant_topic_types}"
f"Invalid next speaker: {next_speaker} from the ledger, participants are: {self._participant_names}"
)
participant_topic_type = self._participant_name_to_topic_type[next_speaker]
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=participant_topic_type),
cancellation_token=cancellation_token,
)
async def _update_task_ledger(self, cancellation_token: CancellationToken) -> None:
"""Update the task ledger (outer loop) with the latest facts and plan."""
@@ -436,11 +445,13 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
self._message_thread.append(message) # My copy
# Log it
# Log it to the output topic.
await self.publish_message(
GroupChatMessage(message=message),
topic_id=DefaultTopicId(type=self._output_topic_type),
)
# Log it to the output queue.
await self._output_message_queue.put(message)
# Broadcast
await self.publish_message(
@@ -449,13 +460,10 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
cancellation_token=cancellation_token,
)
# Signal termination
await self.publish_message(
GroupChatTermination(message=StopMessage(content=reason, source=self._name)),
topic_id=DefaultTopicId(type=self._output_topic_type),
)
if self._termination_condition is not None:
await self._termination_condition.reset()
# Signal termination
await self._signal_termination(StopMessage(content=reason, source=self._name))
def _thread_to_context(self) -> List[LLMMessage]:
"""Convert the message thread to a context for the model."""

View File

@@ -1,6 +1,7 @@
import asyncio
from typing import Any, Callable, List, Mapping
from autogen_core import Component, ComponentModel
from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel
from typing_extensions import Self
@@ -9,6 +10,7 @@ from ...messages import AgentEvent, ChatMessage
from ...state import RoundRobinManagerState
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
from ._events import GroupChatTermination
class RoundRobinGroupChatManager(BaseGroupChatManager):
@@ -16,18 +18,24 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
def __init__(
self,
name: str,
group_topic_type: str,
output_topic_type: str,
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None = None,
) -> None:
super().__init__(
name,
group_topic_type,
output_topic_type,
participant_topic_types,
participant_names,
participant_descriptions,
output_message_queue,
termination_condition,
max_turns,
)
@@ -60,8 +68,8 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
"""Select a speaker from the participants in a round-robin fashion."""
current_speaker_index = self._next_speaker_index
self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_topic_types)
current_speaker = self._participant_topic_types[current_speaker_index]
self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_names)
current_speaker = self._participant_names[current_speaker_index]
return current_speaker
@@ -148,34 +156,45 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
component_config_schema = RoundRobinGroupChatConfig
component_provider_override = "autogen_agentchat.teams.RoundRobinGroupChat"
# TODO: Add * to the constructor to separate the positional parameters from the kwargs.
# This may be a breaking change so let's wait until a good time to do it.
def __init__(
self,
participants: List[ChatAgent],
termination_condition: TerminationCondition | None = None,
max_turns: int | None = None,
runtime: AgentRuntime | None = None,
) -> None:
super().__init__(
participants,
group_chat_manager_name="RoundRobinGroupChatManager",
group_chat_manager_class=RoundRobinGroupChatManager,
termination_condition=termination_condition,
max_turns=max_turns,
runtime=runtime,
)
def _create_group_chat_manager_factory(
self,
name: str,
group_topic_type: str,
output_topic_type: str,
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
) -> Callable[[], RoundRobinGroupChatManager]:
def _factory() -> RoundRobinGroupChatManager:
return RoundRobinGroupChatManager(
name,
group_topic_type,
output_topic_type,
participant_topic_types,
participant_names,
participant_descriptions,
output_message_queue,
termination_condition,
max_turns,
)

View File

@@ -1,8 +1,9 @@
import asyncio
import logging
import re
from typing import Any, Callable, Dict, List, Mapping, Sequence
from autogen_core import Component, ComponentModel
from autogen_core import AgentRuntime, Component, ComponentModel
from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage
from pydantic import BaseModel
from typing_extensions import Self
@@ -19,6 +20,7 @@ from ...messages import (
from ...state import SelectorManagerState
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
from ._events import GroupChatTermination
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
@@ -29,10 +31,13 @@ class SelectorGroupChatManager(BaseGroupChatManager):
def __init__(
self,
name: str,
group_topic_type: str,
output_topic_type: str,
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
model_client: ChatCompletionClient,
@@ -42,10 +47,13 @@ class SelectorGroupChatManager(BaseGroupChatManager):
max_selector_attempts: int,
) -> None:
super().__init__(
name,
group_topic_type,
output_topic_type,
participant_topic_types,
participant_names,
participant_descriptions,
output_message_queue,
termination_condition,
max_turns,
)
@@ -91,6 +99,11 @@ class SelectorGroupChatManager(BaseGroupChatManager):
if self._selector_func is not None:
speaker = self._selector_func(thread)
if speaker is not None:
if speaker not in self._participant_names:
raise ValueError(
f"Selector function returned an invalid speaker name: {speaker}. "
f"Expected one of: {self._participant_names}."
)
# Skip the model based selection.
return speaker
@@ -100,7 +113,6 @@ class SelectorGroupChatManager(BaseGroupChatManager):
if isinstance(msg, BaseAgentEvent):
# Ignore agent events.
continue
# The agent type must be the same as the topic type, which we use as the agent name.
message = f"{msg.source}:"
if isinstance(msg.content, str):
message += f" {msg.content}"
@@ -117,18 +129,18 @@ class SelectorGroupChatManager(BaseGroupChatManager):
) # Create some consistency for how messages are separated in the transcript
history = "\n".join(history_messages)
# Construct agent roles, we are using the participant topic type as the agent name.
# Construct agent roles.
# Each agent should appear on a single line.
roles = ""
for topic_type, description in zip(self._participant_topic_types, self._participant_descriptions, strict=True):
for topic_type, description in zip(self._participant_names, self._participant_descriptions, strict=True):
roles += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n"
roles = roles.strip()
# Construct agent list to be selected, skip the previous speaker if not allowed.
# Construct the candidate agent list to be selected from, skip the previous speaker if not allowed.
if self._previous_speaker is not None and not self._allow_repeated_speaker:
participants = [p for p in self._participant_topic_types if p != self._previous_speaker]
participants = [p for p in self._participant_names if p != self._previous_speaker]
else:
participants = self._participant_topic_types
participants = list(self._participant_names)
assert len(participants) > 0
# Select the next speaker.
@@ -157,7 +169,9 @@ class SelectorGroupChatManager(BaseGroupChatManager):
response = await self._model_client.create(messages=select_speaker_messages)
assert isinstance(response.content, str)
select_speaker_messages.append(AssistantMessage(content=response.content, source="selector"))
mentions = self._mentioned_agents(response.content, self._participant_topic_types)
# NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed.
# This is because the model may still select the previous speaker, and we want to catch that.
mentions = self._mentioned_agents(response.content, self._participant_names)
if len(mentions) == 0:
trace_logger.debug(f"Model failed to select a valid name: {response.content} (attempt {num_attempts})")
feedback = f"No valid name was mentioned. Please select from: {str(participants)}."
@@ -391,6 +405,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
*,
termination_condition: TerminationCondition | None = None,
max_turns: int | None = None,
runtime: AgentRuntime | None = None,
selector_prompt: str = """You are in a role play game. The following roles are available:
{roles}.
Read the following conversation. Then select the next role from {participants} to play. Only return the role.
@@ -405,9 +420,11 @@ Read the above conversation. Then select the next role from {participants} to pl
):
super().__init__(
participants,
group_chat_manager_name="SelectorGroupChatManager",
group_chat_manager_class=SelectorGroupChatManager,
termination_condition=termination_condition,
max_turns=max_turns,
runtime=runtime,
)
# Validate the participants.
if len(participants) < 2:
@@ -420,18 +437,24 @@ Read the above conversation. Then select the next role from {participants} to pl
def _create_group_chat_manager_factory(
self,
name: str,
group_topic_type: str,
output_topic_type: str,
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
) -> Callable[[], BaseGroupChatManager]:
return lambda: SelectorGroupChatManager(
name,
group_topic_type,
output_topic_type,
participant_topic_types,
participant_names,
participant_descriptions,
output_message_queue,
termination_condition,
max_turns,
self._model_client,

View File

@@ -1,6 +1,7 @@
import asyncio
from typing import Any, Callable, List, Mapping
from autogen_core import Component, ComponentModel
from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel
from ...base import ChatAgent, TerminationCondition
@@ -8,6 +9,7 @@ from ...messages import AgentEvent, ChatMessage, HandoffMessage
from ...state import SwarmManagerState
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
from ._events import GroupChatTermination
class SwarmGroupChatManager(BaseGroupChatManager):
@@ -15,22 +17,28 @@ class SwarmGroupChatManager(BaseGroupChatManager):
def __init__(
self,
name: str,
group_topic_type: str,
output_topic_type: str,
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
) -> None:
super().__init__(
name,
group_topic_type,
output_topic_type,
participant_topic_types,
participant_names,
participant_descriptions,
output_message_queue,
termination_condition,
max_turns,
)
self._current_speaker = participant_topic_types[0]
self._current_speaker = self._participant_names[0]
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
"""Validate the start messages for the group chat."""
@@ -38,9 +46,9 @@ class SwarmGroupChatManager(BaseGroupChatManager):
if messages:
for message in messages:
if isinstance(message, HandoffMessage):
if message.target not in self._participant_topic_types:
if message.target not in self._participant_names:
raise ValueError(
f"The target {message.target} is not one of the participants {self._participant_topic_types}. "
f"The target {message.target} is not one of the participants {self._participant_names}. "
"If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target."
)
return
@@ -48,9 +56,9 @@ class SwarmGroupChatManager(BaseGroupChatManager):
# Check if there is a handoff message in the thread that is not targeting a valid participant.
for existing_message in reversed(self._message_thread):
if isinstance(existing_message, HandoffMessage):
if existing_message.target not in self._participant_topic_types:
if existing_message.target not in self._participant_names:
raise ValueError(
f"The existing handoff target {existing_message.target} is not one of the participants {self._participant_topic_types}. "
f"The existing handoff target {existing_message.target} is not one of the participants {self._participant_names}. "
"If you are resuming Swarm with a new task make sure to include in your task "
"a HandoffMessage with a valid participant as the target. For example, if you are "
"resuming from a HandoffTermination, make sure the new task is a HandoffMessage "
@@ -65,7 +73,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
self._message_thread.clear()
if self._termination_condition is not None:
await self._termination_condition.reset()
self._current_speaker = self._participant_topic_types[0]
self._current_speaker = self._participant_names[0]
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
"""Select a speaker from the participants based on handoff message.
@@ -76,7 +84,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
if isinstance(message, HandoffMessage):
self._current_speaker = message.target
# The latest handoff message should always target a valid participant.
assert self._current_speaker in self._participant_topic_types
assert self._current_speaker in self._participant_names
return self._current_speaker
return self._current_speaker
@@ -194,17 +202,22 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
component_config_schema = SwarmConfig
component_provider_override = "autogen_agentchat.teams.Swarm"
# TODO: Add * to the constructor to separate the positional parameters from the kwargs.
# This may be a breaking change so let's wait until a good time to do it.
def __init__(
self,
participants: List[ChatAgent],
termination_condition: TerminationCondition | None = None,
max_turns: int | None = None,
runtime: AgentRuntime | None = None,
) -> None:
super().__init__(
participants,
group_chat_manager_name="SwarmGroupChatManager",
group_chat_manager_class=SwarmGroupChatManager,
termination_condition=termination_condition,
max_turns=max_turns,
runtime=runtime,
)
# The first participant must be able to produce handoff messages.
first_participant = self._participants[0]
@@ -213,19 +226,25 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
def _create_group_chat_manager_factory(
self,
name: str,
group_topic_type: str,
output_topic_type: str,
participant_topic_types: List[str],
participant_names: List[str],
participant_descriptions: List[str],
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
termination_condition: TerminationCondition | None,
max_turns: int | None,
) -> Callable[[], SwarmGroupChatManager]:
def _factory() -> SwarmGroupChatManager:
return SwarmGroupChatManager(
name,
group_topic_type,
output_topic_type,
participant_topic_types,
participant_names,
participant_descriptions,
output_message_queue,
termination_condition,
max_turns,
)

View File

@@ -1,7 +1,6 @@
import asyncio
import json
import logging
from typing import Any, AsyncGenerator, List
from typing import List
import pytest
from autogen_agentchat import EVENT_LOGGER_NAME
@@ -35,15 +34,6 @@ from autogen_core.models._model_client import ModelFamily
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.replay import ReplayChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)
from openai.types.completion_usage import CompletionUsage
from utils import FileLogHandler
logger = logging.getLogger(EVENT_LOGGER_NAME)
@@ -51,22 +41,6 @@ logger.setLevel(logging.DEBUG)
logger.addHandler(FileLogHandler("test_assistant_agent.log"))
class _MockChatCompletion:
def __init__(self, chat_completions: List[ChatCompletion]) -> None:
self._saved_chat_completions = chat_completions
self.curr_index = 0
self.calls: List[List[LLMMessage]] = []
async def mock_create(
self, *args: Any, **kwargs: Any
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
self.calls.append(kwargs["messages"]) # Save the call
await asyncio.sleep(0.1)
completion = self._saved_chat_completions[self.curr_index]
self.curr_index += 1
return completion
def _pass_function(input: str) -> str:
return "pass"
@@ -81,69 +55,23 @@ async def _echo_function(input: str) -> str:
@pytest.mark.asyncio
async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content="Calling pass function",
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task"}),
),
)
],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="pass", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
model_client = ReplayChatCompletionClient(
[
CreateResult(
finish_reason="function_calls",
content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_function")],
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
thought="Calling pass function",
cached=False,
),
"pass",
"TERMINATE",
],
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
)
agent = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
tools=[
_pass_function,
_fail_function,
@@ -168,7 +96,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.messages[4].models_usage is None
# Test streaming.
mock.curr_index = 0 # Reset the mock
model_client.reset()
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
@@ -181,7 +109,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
state = await agent.save_state()
agent2 = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
)
await agent2.load_state(state)
@@ -190,64 +118,33 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.mark.asyncio
async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task"}),
),
)
],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
async def test_run_with_tools_and_reflection() -> None:
model_client = ReplayChatCompletionClient(
[
CreateResult(
finish_reason="function_calls",
content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_function")],
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
),
CreateResult(
finish_reason="stop",
content="Hello",
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
),
CreateResult(
finish_reason="stop",
content="TERMINATE",
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
),
],
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
)
agent = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
reflect_on_tool_use=True,
)
@@ -269,7 +166,7 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
assert result.messages[3].models_usage.prompt_tokens == 10
# Test streaming.
mock.curr_index = 0 # pyright: ignore
model_client.reset()
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
@@ -282,7 +179,7 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
state = await agent.save_state()
agent2 = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
tools=[
_pass_function,
_fail_function,
@@ -295,86 +192,28 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
@pytest.mark.asyncio
async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content="Calling pass and echo functions",
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task1"}),
),
),
ChatCompletionMessageToolCall(
id="2",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task2"}),
),
),
ChatCompletionMessageToolCall(
id="3",
type="function",
function=Function(
name="_echo_function",
arguments=json.dumps({"input": "task3"}),
),
),
],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="pass", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
async def test_run_with_parallel_tools() -> None:
model_client = ReplayChatCompletionClient(
[
CreateResult(
finish_reason="function_calls",
content=[
FunctionCall(id="1", arguments=json.dumps({"input": "task1"}), name="_pass_function"),
FunctionCall(id="2", arguments=json.dumps({"input": "task2"}), name="_pass_function"),
FunctionCall(id="3", arguments=json.dumps({"input": "task3"}), name="_echo_function"),
],
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
thought="Calling pass and echo functions",
cached=False,
),
"pass",
"TERMINATE",
],
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
)
agent = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
tools=[
_pass_function,
_fail_function,
@@ -411,7 +250,7 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.messages[4].models_usage is None
# Test streaming.
mock.curr_index = 0 # Reset the mock
model_client.reset()
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
@@ -424,7 +263,7 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
state = await agent.save_state()
agent2 = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
)
await agent2.load_state(state)
@@ -433,86 +272,27 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.mark.asyncio
async def test_run_with_parallel_tools_with_empty_call_ids(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task1"}),
),
),
ChatCompletionMessageToolCall(
id="",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task2"}),
),
),
ChatCompletionMessageToolCall(
id="",
type="function",
function=Function(
name="_echo_function",
arguments=json.dumps({"input": "task3"}),
),
),
],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="pass", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
model_client = ReplayChatCompletionClient(
[
CreateResult(
finish_reason="function_calls",
content=[
FunctionCall(id="", arguments=json.dumps({"input": "task1"}), name="_pass_function"),
FunctionCall(id="", arguments=json.dumps({"input": "task2"}), name="_pass_function"),
FunctionCall(id="", arguments=json.dumps({"input": "task3"}), name="_echo_function"),
],
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
),
"pass",
"TERMINATE",
],
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
)
agent = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
tools=[
_pass_function,
_fail_function,
@@ -547,7 +327,7 @@ async def test_run_with_parallel_tools_with_empty_call_ids(monkeypatch: pytest.M
assert result.messages[3].models_usage is None
# Test streaming.
mock.curr_index = 0 # Reset the mock
model_client.reset()
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
@@ -560,7 +340,7 @@ async def test_run_with_parallel_tools_with_empty_call_ids(monkeypatch: pytest.M
state = await agent.save_state()
agent2 = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
)
await agent2.load_state(state)
@@ -569,43 +349,24 @@ async def test_run_with_parallel_tools_with_empty_call_ids(monkeypatch: pytest.M
@pytest.mark.asyncio
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
async def test_handoffs() -> None:
handoff = Handoff(target="agent2")
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function(
name=handoff.name,
arguments=json.dumps({}),
),
)
],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=42, completion_tokens=43, total_tokens=85),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
model_client = ReplayChatCompletionClient(
[
CreateResult(
finish_reason="function_calls",
content=[
FunctionCall(id="1", arguments=json.dumps({}), name=handoff.name),
],
usage=RequestUsage(prompt_tokens=42, completion_tokens=43),
cached=False,
)
],
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
)
tool_use_agent = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
tools=[
_pass_function,
_fail_function,
@@ -630,7 +391,7 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.messages[3].models_usage is None
# Test streaming.
mock.curr_index = 0 # pyright: ignore
model_client.reset()
index = 0
async for message in tool_use_agent.run_stream(task="task"):
if isinstance(message, TaskResult):
@@ -642,28 +403,10 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.mark.asyncio
async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="Hello", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
model_client = ReplayChatCompletionClient(["Hello"])
agent = AssistantAgent(
name="assistant",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
)
# Generate a random base64 image.
img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
@@ -698,7 +441,7 @@ async def test_invalid_model_capabilities() -> None:
@pytest.mark.asyncio
async def test_remove_images(monkeypatch: pytest.MonkeyPatch) -> None:
async def test_remove_images() -> None:
model = "random-model"
model_client_1 = OpenAIChatCompletionClient(
model=model,
@@ -732,28 +475,19 @@ async def test_remove_images(monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.mark.asyncio
async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="Response to message 1", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
model_client = ReplayChatCompletionClient(
[
CreateResult(
finish_reason="stop",
content="Response to message 1",
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
cached=False,
)
]
)
agent = AssistantAgent(
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
)
# Create a list of chat messages
@@ -779,7 +513,7 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.messages[2].models_usage.prompt_tokens == 10
# Test run_stream method with list of messages
mock.curr_index = 0 # Reset mock index using public attribute
model_client.reset() # Reset the mock client
index = 0
async for message in agent.run_stream(task=messages):
if isinstance(message, TaskResult):
@@ -791,29 +525,11 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.mark.asyncio
async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="Response to message 3", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
model_client = ReplayChatCompletionClient(["Response to message 3"])
model_context = BufferedChatCompletionContext(buffer_size=2)
agent = AssistantAgent(
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
model_context=model_context,
)
@@ -825,33 +541,15 @@ async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None:
await agent.run(task=messages)
# Check if the mock client is called with only the last two messages.
assert len(mock.calls) == 1
assert len(model_client.create_calls) == 1
# 2 message from the context + 1 system message
assert len(mock.calls[0]) == 3
assert len(model_client.create_calls[0]["messages"]) == 3
@pytest.mark.asyncio
async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="Hello", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
model_client = ReplayChatCompletionClient(["Hello"])
b64_image_str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
# Test basic memory properties and empty context
memory = ListMemory(name="test_memory")
@@ -883,7 +581,7 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
with pytest.raises(TypeError):
AssistantAgent(
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
memory="invalid", # type: ignore
)
@@ -891,9 +589,7 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
memory2 = ListMemory()
await memory2.add(MemoryContent(content="test instruction", mime_type=MemoryMimeType.TEXT))
agent = AssistantAgent(
"test_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), memory=[memory2]
)
agent = AssistantAgent("test_agent", model_client=model_client, memory=[memory2])
# Test dump and load component with memory
agent_config: ComponentModel = agent.dump_component()
@@ -916,30 +612,15 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
@pytest.mark.asyncio
async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="Response to message 3", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
async def test_assistant_agent_declarative() -> None:
model_client = ReplayChatCompletionClient(
["Response to message 3"],
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
)
model_context = BufferedChatCompletionContext(buffer_size=2)
agent = AssistantAgent(
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
model_context=model_context,
memory=[ListMemory(name="test_memory")],
)
@@ -952,7 +633,7 @@ async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> N
agent3 = AssistantAgent(
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
model_client=model_client,
model_context=model_context,
tools=[
_pass_function,

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,10 @@
import asyncio
import json
import logging
from typing import Sequence
from typing import AsyncGenerator, Sequence
import pytest
import pytest_asyncio
from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import (
BaseChatAgent,
@@ -17,7 +18,7 @@ from autogen_agentchat.teams import (
MagenticOneGroupChat,
)
from autogen_agentchat.teams._group_chat._magentic_one._magentic_one_orchestrator import MagenticOneOrchestrator
from autogen_core import AgentId, CancellationToken
from autogen_core import AgentId, AgentRuntime, CancellationToken, SingleThreadedAgentRuntime
from autogen_ext.models.replay import ReplayChatCompletionClient
from utils import FileLogHandler
@@ -55,8 +56,19 @@ class _EchoAgent(BaseChatAgent):
self._last_message = None
@pytest_asyncio.fixture(params=["single_threaded", "embedded"])  # type: ignore
async def runtime(request: pytest.FixtureRequest) -> AsyncGenerator[AgentRuntime | None, None]:
    """Parametrized runtime fixture: a started SingleThreadedAgentRuntime, or None.

    Yielding None exercises the "embedded" path where the team constructs and
    manages its own runtime internally.
    """
    if request.param == "embedded":
        yield None
    elif request.param == "single_threaded":
        agent_runtime = SingleThreadedAgentRuntime()
        agent_runtime.start()
        yield agent_runtime
        # Teardown: stop the external runtime after the test finishes with it.
        await agent_runtime.stop()
@pytest.mark.asyncio
async def test_magentic_one_group_chat_cancellation() -> None:
async def test_magentic_one_group_chat_cancellation(runtime: AgentRuntime | None) -> None:
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
agent_2 = _EchoAgent("agent_2", description="echo agent 2")
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
@@ -67,7 +79,9 @@ async def test_magentic_one_group_chat_cancellation() -> None:
)
# Set max_turns to a large number to avoid stopping due to max_turns before cancellation.
team = MagenticOneGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client)
team = MagenticOneGroupChat(
participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client, runtime=runtime
)
cancellation_token = CancellationToken()
run_task = asyncio.create_task(
team.run(
@@ -83,7 +97,7 @@ async def test_magentic_one_group_chat_cancellation() -> None:
@pytest.mark.asyncio
async def test_magentic_one_group_chat_basic() -> None:
async def test_magentic_one_group_chat_basic(runtime: AgentRuntime | None) -> None:
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
agent_2 = _EchoAgent("agent_2", description="echo agent 2")
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
@@ -115,7 +129,9 @@ async def test_magentic_one_group_chat_basic() -> None:
],
)
team = MagenticOneGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client)
team = MagenticOneGroupChat(
participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client, runtime=runtime
)
result = await team.run(task="Write a program that prints 'Hello, world!'")
assert len(result.messages) == 5
assert result.messages[2].content == "Continue task"
@@ -124,16 +140,18 @@ async def test_magentic_one_group_chat_basic() -> None:
# Test save and load.
state = await team.save_state()
team2 = MagenticOneGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client)
team2 = MagenticOneGroupChat(
participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client, runtime=runtime
)
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
manager_1 = await team._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team._team_id), # pyright: ignore
AgentId(f"{team._group_chat_manager_name}_{team._team_id}", team._team_id), # pyright: ignore
MagenticOneOrchestrator, # pyright: ignore
) # pyright: ignore
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team2._team_id), # pyright: ignore
AgentId(f"{team2._group_chat_manager_name}_{team2._team_id}", team2._team_id), # pyright: ignore
MagenticOneOrchestrator, # pyright: ignore
) # pyright: ignore
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
@@ -145,7 +163,7 @@ async def test_magentic_one_group_chat_basic() -> None:
@pytest.mark.asyncio
async def test_magentic_one_group_chat_with_stalls() -> None:
async def test_magentic_one_group_chat_with_stalls(runtime: AgentRuntime | None) -> None:
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
agent_2 = _EchoAgent("agent_2", description="echo agent 2")
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
@@ -189,7 +207,10 @@ async def test_magentic_one_group_chat_with_stalls() -> None:
)
team = MagenticOneGroupChat(
participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client, max_stalls=2
participants=[agent_1, agent_2, agent_3, agent_4],
model_client=model_client,
max_stalls=2,
runtime=runtime,
)
result = await team.run(task="Write a program that prints 'Hello, world!'")
assert len(result.messages) == 6

View File

@@ -1,75 +1,34 @@
import asyncio
from typing import Any, AsyncGenerator, List
from typing import AsyncGenerator
import pytest
import pytest_asyncio
from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_ext.models.openai import OpenAIChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.completion_usage import CompletionUsage
from autogen_core import AgentRuntime, SingleThreadedAgentRuntime
from autogen_ext.models.replay import ReplayChatCompletionClient
class _MockChatCompletion:
    """Stub for ``AsyncCompletions.create`` that replays canned responses in order."""

    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
        # Responses are handed out sequentially; _curr_index tracks progress.
        self._saved_chat_completions = chat_completions
        self._curr_index = 0

    async def mock_create(
        self, *args: Any, **kwargs: Any
    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
        # Small delay to mimic network latency for ordering/cancellation tests.
        await asyncio.sleep(0.1)
        result = self._saved_chat_completions[self._curr_index]
        self._curr_index += 1
        return result
@pytest_asyncio.fixture(params=["single_threaded", "embedded"])  # type: ignore
async def runtime(request: pytest.FixtureRequest) -> AsyncGenerator[AgentRuntime | None, None]:
    """Parametrized runtime fixture: a started SingleThreadedAgentRuntime, or None.

    None selects the "embedded" mode, where the team owns its runtime.
    """
    if request.param == "embedded":
        yield None
    elif request.param == "single_threaded":
        external_runtime = SingleThreadedAgentRuntime()
        external_runtime.start()
        yield external_runtime
        # Stop the externally-managed runtime once the test is done.
        await external_runtime.stop()
@pytest.mark.asyncio
async def test_society_of_mind_agent(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id2",
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="1", role="assistant"))
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="2", role="assistant"))
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="3", role="assistant"))
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="")
async def test_society_of_mind_agent(runtime: AgentRuntime | None) -> None:
model_client = ReplayChatCompletionClient(
["1", "2", "3"],
)
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
inner_termination = MaxMessageTermination(3)
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination, runtime=runtime)
society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
response = await society_of_mind_agent.run(task="Count to 10.")
assert len(response.messages) == 4
@@ -84,14 +43,13 @@ async def test_society_of_mind_agent(monkeypatch: pytest.MonkeyPatch) -> None:
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
inner_termination = MaxMessageTermination(3)
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination, runtime=runtime)
society_of_mind_agent2 = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
await society_of_mind_agent2.load_state(state)
state2 = await society_of_mind_agent2.save_state()
assert state == state2
# Test serialization.
soc_agent_config = society_of_mind_agent.dump_component()
assert soc_agent_config.provider == "autogen_agentchat.agents.SocietyOfMindAgent"

View File

@@ -2,9 +2,10 @@ from __future__ import annotations
import logging
import warnings
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Sequence, Union
from typing_extensions import Self
from autogen_core import EVENT_LOGGER_NAME, CancellationToken
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component
from autogen_core.models import (
ChatCompletionClient,
CreateResult,
@@ -13,13 +14,22 @@ from autogen_core.models import (
ModelFamily,
ModelInfo,
RequestUsage,
validate_model_info,
)
from autogen_core.tools import Tool, ToolSchema
from pydantic import BaseModel
logger = logging.getLogger(EVENT_LOGGER_NAME)
class ReplayChatCompletionClient(ChatCompletionClient):
class ReplayChatCompletionClientConfig(BaseModel):
    """ReplayChatCompletionClient configuration."""

    # Responses to replay in order; plain strings stand in for simple text completions.
    chat_completions: Sequence[Union[str, CreateResult]]
    # Model capabilities to advertise; when None the client falls back to a minimal default.
    model_info: Optional[ModelInfo] = None
class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompletionClientConfig]):
"""
A mock chat completion client that replays predefined responses using an index-based approach.
@@ -111,25 +121,37 @@ class ReplayChatCompletionClient(ChatCompletionClient):
"""
__protocol__: ChatCompletionClient
component_type = "replay_chat_completion_client"
component_provider_override = "autogen_ext.models.replay.ReplayChatCompletionClient"
component_config_schema = ReplayChatCompletionClientConfig
# TODO: Support FunctionCall in responses
# TODO: Support logprobs in Responses
# TODO: Support model capabilities
def __init__(
self,
chat_completions: Sequence[Union[str, CreateResult]],
model_info: Optional[ModelInfo] = None,
):
self.chat_completions = list(chat_completions)
self.provided_message_count = len(self.chat_completions)
self._model_info = ModelInfo(
vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
)
if model_info is not None:
self._model_info = model_info
validate_model_info(self._model_info)
else:
self._model_info = ModelInfo(
vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
)
self._total_available_tokens = 10000
self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
self._current_index = 0
self._cached_bool_value = True
self._create_calls: List[Dict[str, Any]] = []
    @property
    def create_calls(self) -> List[Dict[str, Any]]:
        """Return the arguments of the calls made to the create method.

        Each entry is a dict with keys ``messages``, ``tools``, ``json_output``,
        ``extra_create_args``, and ``cancellation_token``, in call order.
        """
        return self._create_calls
async def create(
self,
@@ -159,6 +181,15 @@ class ReplayChatCompletionClient(ChatCompletionClient):
self._update_total_usage()
self._current_index += 1
self._create_calls.append(
{
"messages": messages,
"tools": tools,
"json_output": json_output,
"extra_create_args": extra_create_args,
"cancellation_token": cancellation_token,
}
)
return response
async def create_stream(
@@ -259,3 +290,16 @@ class ReplayChatCompletionClient(ChatCompletionClient):
self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
self._current_index = 0
def _to_config(self) -> ReplayChatCompletionClientConfig:
return ReplayChatCompletionClientConfig(
chat_completions=self.chat_completions,
model_info=self._model_info,
)
@classmethod
def _from_config(cls, config: ReplayChatCompletionClientConfig) -> Self:
return cls(
chat_completions=config.chat_completions,
model_info=config.model_info,
)