mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Support for external agent runtime in AgentChat (#5843)
Resolves #4075 1. Introduce custom runtime parameter for all AgentChat teams (RoundRobinGroupChat, SelectorGroupChat, etc.). This is done by making sure each team's topics are isolated from other teams, and decoupling state from agent identities. Also, I removed the closure agent from the BaseGroupChat and use the group chat manager agent to relay messages to the output message queue. 2. Added unit tests to test scenarios with custom runtimes by using pytest fixture 3. Refactored existing unit tests to use ReplayChatCompletionClient with a few improvements to the client. 4. Fix a one-liner bug in AssistantAgent that caused deserialized agent to have handoffs. How to use it? ```python import asyncio from autogen_core import SingleThreadedAgentRuntime from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.teams import RoundRobinGroupChat from autogen_agentchat.conditions import TextMentionTermination from autogen_ext.models.replay import ReplayChatCompletionClient async def main() -> None: # Create a runtime runtime = SingleThreadedAgentRuntime() runtime.start() # Create a model client. model_client = ReplayChatCompletionClient( ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], ) # Create agents agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.") agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.") # Create a termination condition termination_condition = TextMentionTermination("10", sources=["assistant1", "assistant2"]) # Create a team team = RoundRobinGroupChat([agent1, agent2], runtime=runtime, termination_condition=termination_condition) # Run the team stream = team.run_stream(task="Count to 10.") async for message in stream: print(message) # Save the state. state = await team.save_state() # Load the state to an existing team. await team.load_state(state) # Run the team again model_client.reset() stream = team.run_stream(task="Count to 10.") async for message in stream: print(message) # Create a new team, with the same agent names. agent3 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.") agent4 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.") new_team = RoundRobinGroupChat([agent3, agent4], runtime=runtime, termination_condition=termination_condition) # Load the state to the new team. await new_team.load_state(state) # Run the new team model_client.reset() new_stream = new_team.run_stream(task="Count to 10.") async for message in new_stream: print(message) # Stop the runtime await runtime.stop() asyncio.run(main()) ``` TODOs as future PRs: 1. Documentation. 2. How to handle errors in custom runtime when the agent has exception? --------- Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
This commit is contained in:
@@ -1192,7 +1192,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
|
||||
name=self.name,
|
||||
model_client=self._model_client.dump_component(),
|
||||
tools=[tool.dump_component() for tool in self._tools],
|
||||
handoffs=list(self._handoffs.values()),
|
||||
handoffs=list(self._handoffs.values()) if self._handoffs else None,
|
||||
model_context=self._model_context.dump_component(),
|
||||
memory=[memory.dump_component() for memory in self._memory] if self._memory else None,
|
||||
description=self.description,
|
||||
|
||||
@@ -29,7 +29,6 @@ class TeamState(BaseState):
|
||||
"""State for a team of agents."""
|
||||
|
||||
agent_states: Mapping[str, Any] = Field(default_factory=dict)
|
||||
team_id: str = Field(default="")
|
||||
type: str = Field(default="TeamState")
|
||||
|
||||
|
||||
|
||||
@@ -2,29 +2,32 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Callable, List, Mapping, Sequence
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Sequence
|
||||
|
||||
from autogen_core import (
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
AgentRuntime,
|
||||
AgentType,
|
||||
CancellationToken,
|
||||
ClosureAgent,
|
||||
ComponentBase,
|
||||
MessageContext,
|
||||
SingleThreadedAgentRuntime,
|
||||
TypeSubscription,
|
||||
)
|
||||
from autogen_core._closure_agent import ClosureContext
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from ... import EVENT_LOGGER_NAME
|
||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
||||
from ...messages import AgentEvent, BaseChatMessage, ChatMessage, ModelClientStreamingChunkEvent, TextMessage
|
||||
from ...messages import (
|
||||
AgentEvent,
|
||||
BaseChatMessage,
|
||||
ChatMessage,
|
||||
ModelClientStreamingChunkEvent,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from ...state import TeamState
|
||||
from ._chat_agent_container import ChatAgentContainer
|
||||
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
|
||||
from ._events import GroupChatReset, GroupChatStart, GroupChatTermination
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
@@ -42,9 +45,11 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
def __init__(
|
||||
self,
|
||||
participants: List[ChatAgent],
|
||||
group_chat_manager_name: str,
|
||||
group_chat_manager_class: type[SequentialRoutedAgent],
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
):
|
||||
if len(participants) == 0:
|
||||
raise ValueError("At least one participant is required.")
|
||||
@@ -55,23 +60,42 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
self._termination_condition = termination_condition
|
||||
self._max_turns = max_turns
|
||||
|
||||
# Constants for the group chat.
|
||||
# The team ID is a UUID that is used to identify the team and its participants
|
||||
# in the agent runtime. It is used to create unique topic types for each participant.
|
||||
# Currently, team ID is binded to an object instance of the group chat class.
|
||||
# So if you create two instances of group chat, there will be two teams with different IDs.
|
||||
self._team_id = str(uuid.uuid4())
|
||||
self._group_topic_type = "group_topic"
|
||||
self._output_topic_type = "output_topic"
|
||||
self._group_chat_manager_topic_type = "group_chat_manager"
|
||||
self._participant_topic_types: List[str] = [participant.name for participant in participants]
|
||||
self._participant_descriptions: List[str] = [participant.description for participant in participants]
|
||||
self._collector_agent_type = "collect_output_messages"
|
||||
|
||||
# Constants for the closure agent to collect the output messages.
|
||||
self._stop_reason: str | None = None
|
||||
self._output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | None] = asyncio.Queue()
|
||||
# Constants for the group chat team.
|
||||
# The names are used to identify the agents within the team.
|
||||
# The names may not be unique across different teams.
|
||||
self._group_chat_manager_name = group_chat_manager_name
|
||||
self._participant_names: List[str] = [participant.name for participant in participants]
|
||||
self._participant_descriptions: List[str] = [participant.description for participant in participants]
|
||||
# The group chat topic type is used for broadcast communication among all participants and the group chat manager.
|
||||
self._group_topic_type = f"group_topic_{self._team_id}"
|
||||
# The group chat manager topic type is used for direct communication with the group chat manager.
|
||||
self._group_chat_manager_topic_type = f"{self._group_chat_manager_name}_{self._team_id}"
|
||||
# The participant topic types are used for direct communication with each participant.
|
||||
self._participant_topic_types: List[str] = [
|
||||
f"{participant.name}_{self._team_id}" for participant in participants
|
||||
]
|
||||
# The output topic type is used for emitting streaming messages from the group chat.
|
||||
# The group chat manager will relay the messages to the output message queue.
|
||||
self._output_topic_type = f"output_topic_{self._team_id}"
|
||||
|
||||
# The queue for collecting the output messages.
|
||||
self._output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination] = asyncio.Queue()
|
||||
|
||||
# Create a runtime for the team.
|
||||
# TODO: The runtime should be created by a managed context.
|
||||
# Background exceptions must not be ignored as it results in non-surfaced exceptions and early team termination.
|
||||
self._runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
|
||||
if runtime is not None:
|
||||
self._runtime = runtime
|
||||
self._embedded_runtime = False
|
||||
else:
|
||||
# Use a embedded single-threaded runtime for the group chat.
|
||||
# Background exceptions must not be ignored as it results in non-surfaced exceptions and early team termination.
|
||||
self._runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
|
||||
self._embedded_runtime = True
|
||||
|
||||
# Flag to track if the group chat has been initialized.
|
||||
self._initialized = False
|
||||
@@ -82,10 +106,13 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
@abstractmethod
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
) -> Callable[[], SequentialRoutedAgent]: ...
|
||||
@@ -97,10 +124,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
agent: ChatAgent,
|
||||
) -> Callable[[], ChatAgentContainer]:
|
||||
def _factory() -> ChatAgentContainer:
|
||||
id = AgentInstantiationContext.current_agent_id()
|
||||
assert id == AgentId(type=agent.name, key=self._team_id)
|
||||
container = ChatAgentContainer(parent_topic_type, output_topic_type, agent)
|
||||
assert container.id == id
|
||||
return container
|
||||
|
||||
return _factory
|
||||
@@ -110,9 +134,8 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
group_chat_manager_agent_type = AgentType(self._group_chat_manager_topic_type)
|
||||
|
||||
# Register participants.
|
||||
for participant, participant_topic_type in zip(self._participants, self._participant_topic_types, strict=False):
|
||||
# Use the participant topic type as the agent type.
|
||||
agent_type = participant_topic_type
|
||||
# Use the participant topic type as the agent type.
|
||||
for participant, agent_type in zip(self._participants, self._participant_topic_types, strict=True):
|
||||
# Register the participant factory.
|
||||
await ChatAgentContainer.register(
|
||||
runtime,
|
||||
@@ -120,7 +143,9 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
factory=self._create_participant_factory(self._group_topic_type, self._output_topic_type, participant),
|
||||
)
|
||||
# Add subscriptions for the participant.
|
||||
await runtime.add_subscription(TypeSubscription(topic_type=participant_topic_type, agent_type=agent_type))
|
||||
# The participant should be able to receive messages from its own topic.
|
||||
await runtime.add_subscription(TypeSubscription(topic_type=agent_type, agent_type=agent_type))
|
||||
# The participant should be able to receive messages from the group topic.
|
||||
await runtime.add_subscription(TypeSubscription(topic_type=self._group_topic_type, agent_type=agent_type))
|
||||
|
||||
# Register the group chat manager.
|
||||
@@ -128,50 +153,33 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
runtime,
|
||||
type=group_chat_manager_agent_type.type,
|
||||
factory=self._create_group_chat_manager_factory(
|
||||
name=self._group_chat_manager_name,
|
||||
group_topic_type=self._group_topic_type,
|
||||
output_topic_type=self._output_topic_type,
|
||||
participant_names=self._participant_names,
|
||||
participant_topic_types=self._participant_topic_types,
|
||||
participant_descriptions=self._participant_descriptions,
|
||||
output_message_queue=self._output_message_queue,
|
||||
termination_condition=self._termination_condition,
|
||||
max_turns=self._max_turns,
|
||||
),
|
||||
)
|
||||
# Add subscriptions for the group chat manager.
|
||||
# The group chat manager should be able to receive messages from the its own topic.
|
||||
await runtime.add_subscription(
|
||||
TypeSubscription(
|
||||
topic_type=self._group_chat_manager_topic_type, agent_type=group_chat_manager_agent_type.type
|
||||
)
|
||||
)
|
||||
# The group chat manager should be able to receive messages from the group topic.
|
||||
await runtime.add_subscription(
|
||||
TypeSubscription(topic_type=self._group_topic_type, agent_type=group_chat_manager_agent_type.type)
|
||||
)
|
||||
|
||||
async def collect_output_messages(
|
||||
_runtime: ClosureContext,
|
||||
message: GroupChatStart | GroupChatMessage | GroupChatTermination,
|
||||
ctx: MessageContext,
|
||||
) -> None:
|
||||
"""Collect output messages from the group chat."""
|
||||
if isinstance(message, GroupChatStart):
|
||||
if message.messages is not None:
|
||||
for msg in message.messages:
|
||||
event_logger.info(msg)
|
||||
await self._output_message_queue.put(msg)
|
||||
elif isinstance(message, GroupChatMessage):
|
||||
event_logger.info(message.message)
|
||||
await self._output_message_queue.put(message.message)
|
||||
elif isinstance(message, GroupChatTermination):
|
||||
event_logger.info(message.message)
|
||||
self._stop_reason = message.message.content
|
||||
|
||||
await ClosureAgent.register_closure(
|
||||
runtime,
|
||||
type=self._collector_agent_type,
|
||||
closure=collect_output_messages,
|
||||
subscriptions=lambda: [
|
||||
TypeSubscription(topic_type=self._output_topic_type, agent_type=self._collector_agent_type),
|
||||
],
|
||||
# The group chat manager will relay the messages from output topic to the output message queue.
|
||||
await runtime.add_subscription(
|
||||
TypeSubscription(topic_type=self._output_topic_type, agent_type=group_chat_manager_agent_type.type)
|
||||
)
|
||||
|
||||
self._initialized = True
|
||||
|
||||
async def run(
|
||||
@@ -400,26 +408,40 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
raise ValueError("The team is already running, it cannot run again until it is stopped.")
|
||||
self._is_running = True
|
||||
|
||||
# Start the runtime.
|
||||
# TODO: The runtime should be started by a managed context.
|
||||
self._runtime.start()
|
||||
if self._embedded_runtime:
|
||||
# Start the embedded runtime.
|
||||
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
|
||||
self._runtime.start()
|
||||
|
||||
if not self._initialized:
|
||||
await self._init(self._runtime)
|
||||
|
||||
# Start a coroutine to stop the runtime and signal the output message queue is complete.
|
||||
async def stop_runtime() -> None:
|
||||
try:
|
||||
await self._runtime.stop_when_idle()
|
||||
finally:
|
||||
await self._output_message_queue.put(None)
|
||||
shutdown_task: asyncio.Task[None] | None = None
|
||||
if self._embedded_runtime:
|
||||
|
||||
shutdown_task = asyncio.create_task(stop_runtime())
|
||||
async def stop_runtime() -> None:
|
||||
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
|
||||
try:
|
||||
# This will propagate any exceptions raised.
|
||||
await self._runtime.stop_when_idle()
|
||||
finally:
|
||||
# Stop the consumption of messages and end the stream.
|
||||
# NOTE: we also need to put a GroupChatTermination event here because when the group chat
|
||||
# has an exception, the group chat manager may not be able to put a GroupChatTermination event in the queue.
|
||||
await self._output_message_queue.put(
|
||||
GroupChatTermination(
|
||||
message=StopMessage(content="Exception occurred.", source=self._group_chat_manager_name)
|
||||
)
|
||||
)
|
||||
|
||||
# Create a background task to stop the runtime when the group chat
|
||||
# is stopped or has an exception.
|
||||
shutdown_task = asyncio.create_task(stop_runtime())
|
||||
|
||||
try:
|
||||
# Run the team by sending the start message to the group chat manager.
|
||||
# The group chat manager will start the group chat by relaying the message to the participants
|
||||
# and the closure agent.
|
||||
# and the group chat manager.
|
||||
await self._runtime.send_message(
|
||||
GroupChatStart(messages=messages),
|
||||
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
|
||||
@@ -427,6 +449,7 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
)
|
||||
# Collect the output messages in order.
|
||||
output_messages: List[AgentEvent | ChatMessage] = []
|
||||
stop_reason: str | None = None
|
||||
# Yield the messsages until the queue is empty.
|
||||
while True:
|
||||
message_future = asyncio.ensure_future(self._output_message_queue.get())
|
||||
@@ -434,7 +457,13 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
cancellation_token.link_future(message_future)
|
||||
# Wait for the next message, this will raise an exception if the task is cancelled.
|
||||
message = await message_future
|
||||
if message is None:
|
||||
if isinstance(message, GroupChatTermination):
|
||||
# If the message is None, it means the group chat has terminated.
|
||||
# TODO: how do we handle termination when the runtime is not embedded
|
||||
# and there is an exception in the group chat?
|
||||
# The group chat manager may not be able to put a GroupChatTermination event in the queue,
|
||||
# and this loop will never end.
|
||||
stop_reason = message.message.content
|
||||
break
|
||||
yield message
|
||||
if isinstance(message, ModelClientStreamingChunkEvent):
|
||||
@@ -443,14 +472,14 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
output_messages.append(message)
|
||||
|
||||
# Yield the final result.
|
||||
yield TaskResult(messages=output_messages, stop_reason=self._stop_reason)
|
||||
yield TaskResult(messages=output_messages, stop_reason=stop_reason)
|
||||
|
||||
finally:
|
||||
# Wait for the shutdown task to finish.
|
||||
try:
|
||||
# This will propagate any exceptions raised in the shutdown task.
|
||||
# We need to ensure we cleanup though.
|
||||
await shutdown_task
|
||||
if shutdown_task is not None:
|
||||
# Wait for the shutdown task to finish.
|
||||
# This will propagate any exceptions raised.
|
||||
await shutdown_task
|
||||
finally:
|
||||
# Clear the output message queue.
|
||||
while not self._output_message_queue.empty():
|
||||
@@ -506,8 +535,10 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
raise RuntimeError("The group chat is currently running. It must be stopped before it can be reset.")
|
||||
self._is_running = True
|
||||
|
||||
# Start the runtime.
|
||||
self._runtime.start()
|
||||
if self._embedded_runtime:
|
||||
# Start the runtime.
|
||||
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
|
||||
self._runtime.start()
|
||||
|
||||
try:
|
||||
# Send a reset messages to all participants.
|
||||
@@ -522,11 +553,12 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
|
||||
)
|
||||
finally:
|
||||
# Stop the runtime.
|
||||
await self._runtime.stop_when_idle()
|
||||
if self._embedded_runtime:
|
||||
# Stop the runtime.
|
||||
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
|
||||
await self._runtime.stop_when_idle()
|
||||
|
||||
# Reset the output message queue.
|
||||
self._stop_reason = None
|
||||
while not self._output_message_queue.empty():
|
||||
self._output_message_queue.get_nowait()
|
||||
|
||||
@@ -534,7 +566,32 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
self._is_running = False
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""Save the state of the group chat team."""
|
||||
"""Save the state of the group chat team.
|
||||
|
||||
The state is saved by calling the :meth:`~autogen_core.AgentRuntime.agent_save_state` method
|
||||
on each participant and the group chat manager with their internal agent ID.
|
||||
The state is returned as a nested dictionary: a dictionary with key `agent_states`,
|
||||
which is a dictionary the agent names as keys and the state as values.
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
{
|
||||
"agent_states": {
|
||||
"agent1": ...,
|
||||
"agent2": ...,
|
||||
"RoundRobinGroupChatManager": ...
|
||||
}
|
||||
}
|
||||
|
||||
.. note::
|
||||
|
||||
Starting v0.4.9, the state is using the agent name as the key instead of the agent ID,
|
||||
and the `team_id` field is removed from the state. This is to allow the state to be
|
||||
portable across different teams and runtimes. States saved with the old format
|
||||
may not be compatible with the new format in the future.
|
||||
|
||||
"""
|
||||
|
||||
if not self._initialized:
|
||||
raise RuntimeError("The group chat has not been initialized. It must be run before it can be saved.")
|
||||
|
||||
@@ -543,15 +600,31 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
self._is_running = True
|
||||
|
||||
try:
|
||||
# Save the state of the runtime. This will save the state of the participants and the group chat manager.
|
||||
agent_states = await self._runtime.save_state()
|
||||
return TeamState(agent_states=agent_states, team_id=self._team_id).model_dump()
|
||||
# Store state of each agent by their name.
|
||||
# NOTE: we don't use the agent ID as the key here because we need to be able to decouple
|
||||
# the state of the agents from their identities in the agent runtime.
|
||||
agent_states: Dict[str, Mapping[str, Any]] = {}
|
||||
# Save the state of all participants.
|
||||
for name, agent_type in zip(self._participant_names, self._participant_topic_types, strict=True):
|
||||
agent_id = AgentId(type=agent_type, key=self._team_id)
|
||||
# NOTE: We are using the runtime's save state method rather than the agent instance's
|
||||
# save_state method because we want to support saving state of remote agents.
|
||||
agent_states[name] = await self._runtime.agent_save_state(agent_id)
|
||||
# Save the state of the group chat manager.
|
||||
agent_id = AgentId(type=self._group_chat_manager_topic_type, key=self._team_id)
|
||||
agent_states[self._group_chat_manager_name] = await self._runtime.agent_save_state(agent_id)
|
||||
return TeamState(agent_states=agent_states).model_dump()
|
||||
finally:
|
||||
# Indicate that the team is no longer running.
|
||||
self._is_running = False
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""Load the state of the group chat team."""
|
||||
"""Load an external state and overwrite the current state of the group chat team.
|
||||
|
||||
The state is loaded by calling the :meth:`~autogen_core.AgentRuntime.agent_load_state` method
|
||||
on each participant and the group chat manager with their internal agent ID.
|
||||
See :meth:`~autogen_agentchat.teams.BaseGroupChat.save_state` for the expected format of the state.
|
||||
"""
|
||||
if not self._initialized:
|
||||
await self._init(self._runtime)
|
||||
|
||||
@@ -560,10 +633,25 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
|
||||
self._is_running = True
|
||||
|
||||
try:
|
||||
# Load the state of the runtime. This will load the state of the participants and the group chat manager.
|
||||
team_state = TeamState.model_validate(state)
|
||||
self._team_id = team_state.team_id
|
||||
await self._runtime.load_state(team_state.agent_states)
|
||||
# Load the state of all participants.
|
||||
for name, agent_type in zip(self._participant_names, self._participant_topic_types, strict=True):
|
||||
agent_id = AgentId(type=agent_type, key=self._team_id)
|
||||
if name not in team_state.agent_states:
|
||||
raise ValueError(f"Agent state for {name} not found in the saved state.")
|
||||
await self._runtime.agent_load_state(agent_id, team_state.agent_states[name])
|
||||
# Load the state of the group chat manager.
|
||||
agent_id = AgentId(type=self._group_chat_manager_topic_type, key=self._team_id)
|
||||
if self._group_chat_manager_name not in team_state.agent_states:
|
||||
raise ValueError(f"Agent state for {self._group_chat_manager_name} not found in the saved state.")
|
||||
await self._runtime.agent_load_state(agent_id, team_state.agent_states[self._group_chat_manager_name])
|
||||
|
||||
except ValidationError as e:
|
||||
raise ValueError(
|
||||
"Invalid state format. The expected state format has changed since v0.4.9. "
|
||||
"Please read the release note on GitHub."
|
||||
) from e
|
||||
|
||||
finally:
|
||||
# Indicate that the team is no longer running.
|
||||
self._is_running = False
|
||||
|
||||
@@ -8,6 +8,7 @@ from ...base import TerminationCondition
|
||||
from ...messages import AgentEvent, ChatMessage, StopMessage
|
||||
from ._events import (
|
||||
GroupChatAgentResponse,
|
||||
GroupChatMessage,
|
||||
GroupChatRequestPublish,
|
||||
GroupChatReset,
|
||||
GroupChatStart,
|
||||
@@ -30,14 +31,18 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
):
|
||||
super().__init__(description="Group chat manager")
|
||||
self._name = name
|
||||
self._group_topic_type = group_topic_type
|
||||
self._output_topic_type = output_topic_type
|
||||
if len(participant_topic_types) != len(participant_descriptions):
|
||||
@@ -46,9 +51,13 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
raise ValueError("The participant topic types must be unique.")
|
||||
if group_topic_type in participant_topic_types:
|
||||
raise ValueError("The group topic type must not be in the participant topic types.")
|
||||
self._participant_topic_types = participant_topic_types
|
||||
self._participant_names = participant_names
|
||||
self._participant_name_to_topic_type = {
|
||||
name: topic_type for name, topic_type in zip(participant_names, participant_topic_types, strict=True)
|
||||
}
|
||||
self._participant_descriptions = participant_descriptions
|
||||
self._message_thread: List[AgentEvent | ChatMessage] = []
|
||||
self._output_message_queue = output_message_queue
|
||||
self._termination_condition = termination_condition
|
||||
if max_turns is not None and max_turns <= 0:
|
||||
raise ValueError("The maximum number of turns must be greater than 0.")
|
||||
@@ -62,11 +71,11 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
# Check if the conversation has already terminated.
|
||||
if self._termination_condition is not None and self._termination_condition.terminated:
|
||||
early_stop_message = StopMessage(
|
||||
content="The group chat has already terminated.", source="Group chat manager"
|
||||
)
|
||||
await self.publish_message(
|
||||
GroupChatTermination(message=early_stop_message), topic_id=DefaultTopicId(type=self._output_topic_type)
|
||||
content="The group chat has already terminated.",
|
||||
source=self._name,
|
||||
)
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(early_stop_message)
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
@@ -76,8 +85,11 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
if message.messages is not None:
|
||||
# Log all messages at once
|
||||
await self.publish_message(
|
||||
GroupChatStart(messages=message.messages), topic_id=DefaultTopicId(type=self._output_topic_type)
|
||||
GroupChatStart(messages=message.messages),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
for msg in message.messages:
|
||||
await self._output_message_queue.put(msg)
|
||||
|
||||
# Relay all messages at once to participants
|
||||
await self.publish_message(
|
||||
@@ -93,19 +105,21 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
if self._termination_condition is not None:
|
||||
stop_message = await self._termination_condition(message.messages)
|
||||
if stop_message is not None:
|
||||
await self.publish_message(
|
||||
GroupChatTermination(message=stop_message),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Stop the group chat and reset the termination condition.
|
||||
# Reset the termination condition.
|
||||
await self._termination_condition.reset()
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(stop_message)
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Select a speaker to start/continue the conversation
|
||||
speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
|
||||
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
|
||||
# Link the select speaker future to the cancellation token.
|
||||
ctx.cancellation_token.link_future(speaker_topic_type_future)
|
||||
speaker_topic_type = await speaker_topic_type_future
|
||||
ctx.cancellation_token.link_future(speaker_name_future)
|
||||
speaker_name = await speaker_name_future
|
||||
if speaker_name not in self._participant_name_to_topic_type:
|
||||
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
|
||||
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
|
||||
await self.publish_message(
|
||||
GroupChatRequestPublish(),
|
||||
topic_id=DefaultTopicId(type=speaker_topic_type),
|
||||
@@ -127,12 +141,12 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
if self._termination_condition is not None:
|
||||
stop_message = await self._termination_condition(delta)
|
||||
if stop_message is not None:
|
||||
await self.publish_message(
|
||||
GroupChatTermination(message=stop_message), topic_id=DefaultTopicId(type=self._output_topic_type)
|
||||
)
|
||||
# Stop the group chat and reset the termination conditions and turn count.
|
||||
# Reset the termination conditions and turn count.
|
||||
await self._termination_condition.reset()
|
||||
self._current_turn = 0
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(stop_message)
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Increment the turn count.
|
||||
@@ -142,28 +156,45 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
if self._current_turn >= self._max_turns:
|
||||
stop_message = StopMessage(
|
||||
content=f"Maximum number of turns {self._max_turns} reached.",
|
||||
source="Group chat manager",
|
||||
source=self._name,
|
||||
)
|
||||
await self.publish_message(
|
||||
GroupChatTermination(message=stop_message), topic_id=DefaultTopicId(type=self._output_topic_type)
|
||||
)
|
||||
# Stop the group chat and reset the termination conditions and turn count.
|
||||
# Reset the termination conditions and turn count.
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._current_turn = 0
|
||||
# Signal termination to the caller of the team.
|
||||
await self._signal_termination(stop_message)
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Select a speaker to continue the conversation.
|
||||
speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
|
||||
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
|
||||
# Link the select speaker future to the cancellation token.
|
||||
ctx.cancellation_token.link_future(speaker_topic_type_future)
|
||||
speaker_topic_type = await speaker_topic_type_future
|
||||
ctx.cancellation_token.link_future(speaker_name_future)
|
||||
speaker_name = await speaker_name_future
|
||||
if speaker_name not in self._participant_name_to_topic_type:
|
||||
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
|
||||
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
|
||||
await self.publish_message(
|
||||
GroupChatRequestPublish(),
|
||||
topic_id=DefaultTopicId(type=speaker_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
|
||||
async def _signal_termination(self, message: StopMessage) -> None:
|
||||
termination_event = GroupChatTermination(message=message)
|
||||
# Log the early stop message.
|
||||
await self.publish_message(
|
||||
termination_event,
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Put the termination event in the output message queue.
|
||||
await self._output_message_queue.put(termination_event)
|
||||
|
||||
@event
|
||||
async def handle_group_chat_message(self, message: GroupChatMessage, ctx: MessageContext) -> None:
|
||||
await self._output_message_queue.put(message.message)
|
||||
|
||||
@rpc
|
||||
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
|
||||
# Reset the group chat manager.
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Callable, List
|
||||
|
||||
from autogen_core import Component, ComponentModel
|
||||
from autogen_core import AgentRuntime, Component, ComponentModel
|
||||
from autogen_core.models import ChatCompletionClient
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
|
||||
from ....base import ChatAgent, TerminationCondition
|
||||
from ....messages import AgentEvent, ChatMessage
|
||||
from .._base_group_chat import BaseGroupChat
|
||||
from .._events import GroupChatTermination
|
||||
from ._magentic_one_orchestrator import MagenticOneOrchestrator
|
||||
from ._prompts import ORCHESTRATOR_FINAL_ANSWER_PROMPT
|
||||
|
||||
@@ -97,14 +100,17 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
|
||||
*,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = 20,
|
||||
runtime: AgentRuntime | None = None,
|
||||
max_stalls: int = 3,
|
||||
final_answer_prompt: str = ORCHESTRATOR_FINAL_ANSWER_PROMPT,
|
||||
):
|
||||
super().__init__(
|
||||
participants,
|
||||
group_chat_manager_name="MagenticOneOrchestrator",
|
||||
group_chat_manager_class=MagenticOneOrchestrator,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
# Validate the participants.
|
||||
@@ -116,22 +122,28 @@ class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig])
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
) -> Callable[[], MagenticOneOrchestrator]:
|
||||
return lambda: MagenticOneOrchestrator(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
max_turns,
|
||||
self._model_client,
|
||||
self._max_stalls,
|
||||
self._final_answer_prompt,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@@ -53,28 +54,33 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
max_turns: int | None,
|
||||
model_client: ChatCompletionClient,
|
||||
max_stalls: int,
|
||||
final_answer_prompt: str,
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
):
|
||||
super().__init__(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
)
|
||||
self._model_client = model_client
|
||||
self._max_stalls = max_stalls
|
||||
self._final_answer_prompt = final_answer_prompt
|
||||
self._name = "MagenticOneOrchestrator"
|
||||
self._max_json_retries = 10
|
||||
self._task = ""
|
||||
self._facts = ""
|
||||
@@ -84,7 +90,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
|
||||
# Produce a team description. Each agent sould appear on a single line.
|
||||
self._team_description = ""
|
||||
for topic_type, description in zip(self._participant_topic_types, self._participant_descriptions, strict=True):
|
||||
for topic_type, description in zip(self._participant_names, self._participant_descriptions, strict=True):
|
||||
self._team_description += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n"
|
||||
self._team_description = self._team_description.strip()
|
||||
|
||||
@@ -122,9 +128,8 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
# Check if the conversation has already terminated.
|
||||
if self._termination_condition is not None and self._termination_condition.terminated:
|
||||
early_stop_message = StopMessage(content="The group chat has already terminated.", source=self._name)
|
||||
await self.publish_message(
|
||||
GroupChatTermination(message=early_stop_message), topic_id=DefaultTopicId(type=self._output_topic_type)
|
||||
)
|
||||
# Signal termination.
|
||||
await self._signal_termination(early_stop_message)
|
||||
# Stop the group chat.
|
||||
return
|
||||
assert message is not None and message.messages is not None
|
||||
@@ -132,8 +137,12 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
# Validate the group state given all the messages.
|
||||
await self.validate_group_state(message.messages)
|
||||
|
||||
# Log the message.
|
||||
# Log the message to the output topic.
|
||||
await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))
|
||||
# Log the message to the output queue.
|
||||
for msg in message.messages:
|
||||
await self._output_message_queue.put(msg)
|
||||
|
||||
# Outer Loop for first time
|
||||
# Create the initial task ledger
|
||||
#################################
|
||||
@@ -182,11 +191,10 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
if self._termination_condition is not None:
|
||||
stop_message = await self._termination_condition(delta)
|
||||
if stop_message is not None:
|
||||
await self.publish_message(
|
||||
GroupChatTermination(message=stop_message), topic_id=DefaultTopicId(type=self._output_topic_type)
|
||||
)
|
||||
# Stop the group chat and reset the termination conditions and turn count.
|
||||
# Reset the termination conditions.
|
||||
await self._termination_condition.reset()
|
||||
# Signal termination.
|
||||
await self._signal_termination(stop_message)
|
||||
return
|
||||
await self._orchestrate_step(ctx.cancellation_token)
|
||||
|
||||
@@ -233,7 +241,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
async def _reenter_outer_loop(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Re-enter Outer loop of the orchestrator after creating task ledger."""
|
||||
# Reset the agents
|
||||
for participant_topic_type in self._participant_topic_types:
|
||||
for participant_topic_type in self._participant_name_to_topic_type.values():
|
||||
await self._runtime.send_message(
|
||||
GroupChatReset(),
|
||||
recipient=AgentId(type=participant_topic_type, key=self.id.key),
|
||||
@@ -251,11 +259,13 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
# Save my copy
|
||||
self._message_thread.append(ledger_message)
|
||||
|
||||
# Log it
|
||||
# Log it to the output topic.
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=ledger_message),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Log it to the output queue.
|
||||
await self._output_message_queue.put(ledger_message)
|
||||
|
||||
# Broadcast
|
||||
await self.publish_message(
|
||||
@@ -278,7 +288,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
context = self._thread_to_context()
|
||||
|
||||
progress_ledger_prompt = self._get_progress_ledger_prompt(
|
||||
self._task, self._team_description, self._participant_topic_types
|
||||
self._task, self._team_description, self._participant_names
|
||||
)
|
||||
context.append(UserMessage(content=progress_ledger_prompt, source=self._name))
|
||||
progress_ledger: Dict[str, Any] = {}
|
||||
@@ -292,10 +302,10 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
progress_ledger = json.loads(ledger_str)
|
||||
|
||||
# If the team consists of a single agent, deterministically set the next speaker
|
||||
if len(self._participant_topic_types) == 1:
|
||||
if len(self._participant_names) == 1:
|
||||
progress_ledger["next_speaker"] = {
|
||||
"reason": "The team consists of only one agent.",
|
||||
"answer": self._participant_topic_types[0],
|
||||
"answer": self._participant_names[0],
|
||||
}
|
||||
|
||||
# Validate the structure
|
||||
@@ -321,7 +331,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
# Validate the next speaker if the task is not yet complete
|
||||
if (
|
||||
not progress_ledger["is_request_satisfied"]["answer"]
|
||||
and progress_ledger["next_speaker"]["answer"] not in self._participant_topic_types
|
||||
and progress_ledger["next_speaker"]["answer"] not in self._participant_names
|
||||
):
|
||||
key_error = True
|
||||
break
|
||||
@@ -362,12 +372,14 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
message = TextMessage(content=progress_ledger["instruction_or_question"]["answer"], source=self._name)
|
||||
self._message_thread.append(message) # My copy
|
||||
|
||||
# Log it
|
||||
await self._log_message(f"Next Speaker: {progress_ledger['next_speaker']['answer']}")
|
||||
# Log it to the output topic.
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=message),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Log it to the output queue.
|
||||
await self._output_message_queue.put(message)
|
||||
|
||||
# Broadcast it
|
||||
await self.publish_message( # Broadcast
|
||||
@@ -377,21 +389,18 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
)
|
||||
|
||||
# Request that the step be completed
|
||||
valid_next_speaker: bool = False
|
||||
next_speaker = progress_ledger["next_speaker"]["answer"]
|
||||
for participant_topic_type in self._participant_topic_types:
|
||||
if participant_topic_type == next_speaker:
|
||||
await self.publish_message(
|
||||
GroupChatRequestPublish(),
|
||||
topic_id=DefaultTopicId(type=next_speaker),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
valid_next_speaker = True
|
||||
break
|
||||
if not valid_next_speaker:
|
||||
# Check if the next speaker is valid
|
||||
if next_speaker not in self._participant_name_to_topic_type:
|
||||
raise ValueError(
|
||||
f"Invalid next speaker: {next_speaker} from the ledger, participants are: {self._participant_topic_types}"
|
||||
f"Invalid next speaker: {next_speaker} from the ledger, participants are: {self._participant_names}"
|
||||
)
|
||||
participant_topic_type = self._participant_name_to_topic_type[next_speaker]
|
||||
await self.publish_message(
|
||||
GroupChatRequestPublish(),
|
||||
topic_id=DefaultTopicId(type=participant_topic_type),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
async def _update_task_ledger(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Update the task ledger (outer loop) with the latest facts and plan."""
|
||||
@@ -436,11 +445,13 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
|
||||
self._message_thread.append(message) # My copy
|
||||
|
||||
# Log it
|
||||
# Log it to the output topic.
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=message),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
# Log it to the output queue.
|
||||
await self._output_message_queue.put(message)
|
||||
|
||||
# Broadcast
|
||||
await self.publish_message(
|
||||
@@ -449,13 +460,10 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
# Signal termination
|
||||
await self.publish_message(
|
||||
GroupChatTermination(message=StopMessage(content=reason, source=self._name)),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
# Signal termination
|
||||
await self._signal_termination(StopMessage(content=reason, source=self._name))
|
||||
|
||||
def _thread_to_context(self) -> List[LLMMessage]:
|
||||
"""Convert the message thread to a context for the model."""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, List, Mapping
|
||||
|
||||
from autogen_core import Component, ComponentModel
|
||||
from autogen_core import AgentRuntime, Component, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -9,6 +10,7 @@ from ...messages import AgentEvent, ChatMessage
|
||||
from ...state import RoundRobinManagerState
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
from ._events import GroupChatTermination
|
||||
|
||||
|
||||
class RoundRobinGroupChatManager(BaseGroupChatManager):
|
||||
@@ -16,18 +18,24 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
)
|
||||
@@ -60,8 +68,8 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
|
||||
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
|
||||
"""Select a speaker from the participants in a round-robin fashion."""
|
||||
current_speaker_index = self._next_speaker_index
|
||||
self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_topic_types)
|
||||
current_speaker = self._participant_topic_types[current_speaker_index]
|
||||
self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_names)
|
||||
current_speaker = self._participant_names[current_speaker_index]
|
||||
return current_speaker
|
||||
|
||||
|
||||
@@ -148,34 +156,45 @@ class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
|
||||
component_config_schema = RoundRobinGroupChatConfig
|
||||
component_provider_override = "autogen_agentchat.teams.RoundRobinGroupChat"
|
||||
|
||||
# TODO: Add * to the constructor to separate the positional parameters from the kwargs.
|
||||
# This may be a breaking change so let's wait until a good time to do it.
|
||||
def __init__(
|
||||
self,
|
||||
participants: List[ChatAgent],
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
participants,
|
||||
group_chat_manager_name="RoundRobinGroupChatManager",
|
||||
group_chat_manager_class=RoundRobinGroupChatManager,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
) -> Callable[[], RoundRobinGroupChatManager]:
|
||||
def _factory() -> RoundRobinGroupChatManager:
|
||||
return RoundRobinGroupChatManager(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Callable, Dict, List, Mapping, Sequence
|
||||
|
||||
from autogen_core import Component, ComponentModel
|
||||
from autogen_core import AgentRuntime, Component, ComponentModel
|
||||
from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
@@ -19,6 +20,7 @@ from ...messages import (
|
||||
from ...state import SelectorManagerState
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
from ._events import GroupChatTermination
|
||||
|
||||
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
|
||||
|
||||
@@ -29,10 +31,13 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
model_client: ChatCompletionClient,
|
||||
@@ -42,10 +47,13 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
max_selector_attempts: int,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
)
|
||||
@@ -91,6 +99,11 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
if self._selector_func is not None:
|
||||
speaker = self._selector_func(thread)
|
||||
if speaker is not None:
|
||||
if speaker not in self._participant_names:
|
||||
raise ValueError(
|
||||
f"Selector function returned an invalid speaker name: {speaker}. "
|
||||
f"Expected one of: {self._participant_names}."
|
||||
)
|
||||
# Skip the model based selection.
|
||||
return speaker
|
||||
|
||||
@@ -100,7 +113,6 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
if isinstance(msg, BaseAgentEvent):
|
||||
# Ignore agent events.
|
||||
continue
|
||||
# The agent type must be the same as the topic type, which we use as the agent name.
|
||||
message = f"{msg.source}:"
|
||||
if isinstance(msg.content, str):
|
||||
message += f" {msg.content}"
|
||||
@@ -117,18 +129,18 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
) # Create some consistency for how messages are separated in the transcript
|
||||
history = "\n".join(history_messages)
|
||||
|
||||
# Construct agent roles, we are using the participant topic type as the agent name.
|
||||
# Construct agent roles.
|
||||
# Each agent sould appear on a single line.
|
||||
roles = ""
|
||||
for topic_type, description in zip(self._participant_topic_types, self._participant_descriptions, strict=True):
|
||||
for topic_type, description in zip(self._participant_names, self._participant_descriptions, strict=True):
|
||||
roles += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n"
|
||||
roles = roles.strip()
|
||||
|
||||
# Construct agent list to be selected, skip the previous speaker if not allowed.
|
||||
# Construct the candidate agent list to be selected from, skip the previous speaker if not allowed.
|
||||
if self._previous_speaker is not None and not self._allow_repeated_speaker:
|
||||
participants = [p for p in self._participant_topic_types if p != self._previous_speaker]
|
||||
participants = [p for p in self._participant_names if p != self._previous_speaker]
|
||||
else:
|
||||
participants = self._participant_topic_types
|
||||
participants = list(self._participant_names)
|
||||
assert len(participants) > 0
|
||||
|
||||
# Select the next speaker.
|
||||
@@ -157,7 +169,9 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
response = await self._model_client.create(messages=select_speaker_messages)
|
||||
assert isinstance(response.content, str)
|
||||
select_speaker_messages.append(AssistantMessage(content=response.content, source="selector"))
|
||||
mentions = self._mentioned_agents(response.content, self._participant_topic_types)
|
||||
# NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed.
|
||||
# This is because the model may still select the previous speaker, and we want to catch that.
|
||||
mentions = self._mentioned_agents(response.content, self._participant_names)
|
||||
if len(mentions) == 0:
|
||||
trace_logger.debug(f"Model failed to select a valid name: {response.content} (attempt {num_attempts})")
|
||||
feedback = f"No valid name was mentioned. Please select from: {str(participants)}."
|
||||
@@ -391,6 +405,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
|
||||
*,
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
selector_prompt: str = """You are in a role play game. The following roles are available:
|
||||
{roles}.
|
||||
Read the following conversation. Then select the next role from {participants} to play. Only return the role.
|
||||
@@ -405,9 +420,11 @@ Read the above conversation. Then select the next role from {participants} to pl
|
||||
):
|
||||
super().__init__(
|
||||
participants,
|
||||
group_chat_manager_name="SelectorGroupChatManager",
|
||||
group_chat_manager_class=SelectorGroupChatManager,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
)
|
||||
# Validate the participants.
|
||||
if len(participants) < 2:
|
||||
@@ -420,18 +437,24 @@ Read the above conversation. Then select the next role from {participants} to pl
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
) -> Callable[[], BaseGroupChatManager]:
|
||||
return lambda: SelectorGroupChatManager(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
self._model_client,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, List, Mapping
|
||||
|
||||
from autogen_core import Component, ComponentModel
|
||||
from autogen_core import AgentRuntime, Component, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...base import ChatAgent, TerminationCondition
|
||||
@@ -8,6 +9,7 @@ from ...messages import AgentEvent, ChatMessage, HandoffMessage
|
||||
from ...state import SwarmManagerState
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
from ._events import GroupChatTermination
|
||||
|
||||
|
||||
class SwarmGroupChatManager(BaseGroupChatManager):
|
||||
@@ -15,22 +17,28 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
)
|
||||
self._current_speaker = participant_topic_types[0]
|
||||
self._current_speaker = self._participant_names[0]
|
||||
|
||||
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
|
||||
"""Validate the start messages for the group chat."""
|
||||
@@ -38,9 +46,9 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
||||
if messages:
|
||||
for message in messages:
|
||||
if isinstance(message, HandoffMessage):
|
||||
if message.target not in self._participant_topic_types:
|
||||
if message.target not in self._participant_names:
|
||||
raise ValueError(
|
||||
f"The target {message.target} is not one of the participants {self._participant_topic_types}. "
|
||||
f"The target {message.target} is not one of the participants {self._participant_names}. "
|
||||
"If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target."
|
||||
)
|
||||
return
|
||||
@@ -48,9 +56,9 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
||||
# Check if there is a handoff message in the thread that is not targeting a valid participant.
|
||||
for existing_message in reversed(self._message_thread):
|
||||
if isinstance(existing_message, HandoffMessage):
|
||||
if existing_message.target not in self._participant_topic_types:
|
||||
if existing_message.target not in self._participant_names:
|
||||
raise ValueError(
|
||||
f"The existing handoff target {existing_message.target} is not one of the participants {self._participant_topic_types}. "
|
||||
f"The existing handoff target {existing_message.target} is not one of the participants {self._participant_names}. "
|
||||
"If you are resuming Swarm with a new task make sure to include in your task "
|
||||
"a HandoffMessage with a valid participant as the target. For example, if you are "
|
||||
"resuming from a HandoffTermination, make sure the new task is a HandoffMessage "
|
||||
@@ -65,7 +73,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
||||
self._message_thread.clear()
|
||||
if self._termination_condition is not None:
|
||||
await self._termination_condition.reset()
|
||||
self._current_speaker = self._participant_topic_types[0]
|
||||
self._current_speaker = self._participant_names[0]
|
||||
|
||||
async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
|
||||
"""Select a speaker from the participants based on handoff message.
|
||||
@@ -76,7 +84,7 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
||||
if isinstance(message, HandoffMessage):
|
||||
self._current_speaker = message.target
|
||||
# The latest handoff message should always target a valid participant.
|
||||
assert self._current_speaker in self._participant_topic_types
|
||||
assert self._current_speaker in self._participant_names
|
||||
return self._current_speaker
|
||||
return self._current_speaker
|
||||
|
||||
@@ -194,17 +202,22 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
|
||||
component_config_schema = SwarmConfig
|
||||
component_provider_override = "autogen_agentchat.teams.Swarm"
|
||||
|
||||
# TODO: Add * to the constructor to separate the positional parameters from the kwargs.
|
||||
# This may be a breaking change so let's wait until a good time to do it.
|
||||
def __init__(
|
||||
self,
|
||||
participants: List[ChatAgent],
|
||||
termination_condition: TerminationCondition | None = None,
|
||||
max_turns: int | None = None,
|
||||
runtime: AgentRuntime | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
participants,
|
||||
group_chat_manager_name="SwarmGroupChatManager",
|
||||
group_chat_manager_class=SwarmGroupChatManager,
|
||||
termination_condition=termination_condition,
|
||||
max_turns=max_turns,
|
||||
runtime=runtime,
|
||||
)
|
||||
# The first participant must be able to produce handoff messages.
|
||||
first_participant = self._participants[0]
|
||||
@@ -213,19 +226,25 @@ class Swarm(BaseGroupChat, Component[SwarmConfig]):
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
name: str,
|
||||
group_topic_type: str,
|
||||
output_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_names: List[str],
|
||||
participant_descriptions: List[str],
|
||||
output_message_queue: asyncio.Queue[AgentEvent | ChatMessage | GroupChatTermination],
|
||||
termination_condition: TerminationCondition | None,
|
||||
max_turns: int | None,
|
||||
) -> Callable[[], SwarmGroupChatManager]:
|
||||
def _factory() -> SwarmGroupChatManager:
|
||||
return SwarmGroupChatManager(
|
||||
name,
|
||||
group_topic_type,
|
||||
output_topic_type,
|
||||
participant_topic_types,
|
||||
participant_names,
|
||||
participant_descriptions,
|
||||
output_message_queue,
|
||||
termination_condition,
|
||||
max_turns,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncGenerator, List
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from autogen_agentchat import EVENT_LOGGER_NAME
|
||||
@@ -35,15 +34,6 @@ from autogen_core.models._model_client import ModelFamily
|
||||
from autogen_core.tools import FunctionTool
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_ext.models.replay import ReplayChatCompletionClient
|
||||
from openai.resources.chat.completions import AsyncCompletions
|
||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from utils import FileLogHandler
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
@@ -51,22 +41,6 @@ logger.setLevel(logging.DEBUG)
|
||||
logger.addHandler(FileLogHandler("test_assistant_agent.log"))
|
||||
|
||||
|
||||
class _MockChatCompletion:
|
||||
def __init__(self, chat_completions: List[ChatCompletion]) -> None:
|
||||
self._saved_chat_completions = chat_completions
|
||||
self.curr_index = 0
|
||||
self.calls: List[List[LLMMessage]] = []
|
||||
|
||||
async def mock_create(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
|
||||
self.calls.append(kwargs["messages"]) # Save the call
|
||||
await asyncio.sleep(0.1)
|
||||
completion = self._saved_chat_completions[self.curr_index]
|
||||
self.curr_index += 1
|
||||
return completion
|
||||
|
||||
|
||||
def _pass_function(input: str) -> str:
|
||||
return "pass"
|
||||
|
||||
@@ -81,69 +55,23 @@ async def _echo_function(input: str) -> str:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id1",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="tool_calls",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content="Calling pass function",
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id="1",
|
||||
type="function",
|
||||
function=Function(
|
||||
name="_pass_function",
|
||||
arguments=json.dumps({"input": "task"}),
|
||||
),
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="pass", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
model_client = ReplayChatCompletionClient(
|
||||
[
|
||||
CreateResult(
|
||||
finish_reason="function_calls",
|
||||
content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_function")],
|
||||
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
||||
thought="Calling pass function",
|
||||
cached=False,
|
||||
),
|
||||
"pass",
|
||||
"TERMINATE",
|
||||
],
|
||||
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
|
||||
)
|
||||
agent = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
tools=[
|
||||
_pass_function,
|
||||
_fail_function,
|
||||
@@ -168,7 +96,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert result.messages[4].models_usage is None
|
||||
|
||||
# Test streaming.
|
||||
mock.curr_index = 0 # Reset the mock
|
||||
model_client.reset()
|
||||
index = 0
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
@@ -181,7 +109,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
state = await agent.save_state()
|
||||
agent2 = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||
)
|
||||
await agent2.load_state(state)
|
||||
@@ -190,64 +118,33 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id1",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="tool_calls",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id="1",
|
||||
type="function",
|
||||
function=Function(
|
||||
name="_pass_function",
|
||||
arguments=json.dumps({"input": "task"}),
|
||||
),
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
async def test_run_with_tools_and_reflection() -> None:
|
||||
model_client = ReplayChatCompletionClient(
|
||||
[
|
||||
CreateResult(
|
||||
finish_reason="function_calls",
|
||||
content=[FunctionCall(id="1", arguments=json.dumps({"input": "task"}), name="_pass_function")],
|
||||
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
||||
cached=False,
|
||||
),
|
||||
CreateResult(
|
||||
finish_reason="stop",
|
||||
content="Hello",
|
||||
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
||||
cached=False,
|
||||
),
|
||||
CreateResult(
|
||||
finish_reason="stop",
|
||||
content="TERMINATE",
|
||||
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
||||
cached=False,
|
||||
),
|
||||
],
|
||||
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
|
||||
)
|
||||
agent = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||
reflect_on_tool_use=True,
|
||||
)
|
||||
@@ -269,7 +166,7 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
|
||||
assert result.messages[3].models_usage.prompt_tokens == 10
|
||||
|
||||
# Test streaming.
|
||||
mock.curr_index = 0 # pyright: ignore
|
||||
model_client.reset()
|
||||
index = 0
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
@@ -282,7 +179,7 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
|
||||
state = await agent.save_state()
|
||||
agent2 = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
tools=[
|
||||
_pass_function,
|
||||
_fail_function,
|
||||
@@ -295,86 +192,28 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id1",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="tool_calls",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content="Calling pass and echo functions",
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id="1",
|
||||
type="function",
|
||||
function=Function(
|
||||
name="_pass_function",
|
||||
arguments=json.dumps({"input": "task1"}),
|
||||
),
|
||||
),
|
||||
ChatCompletionMessageToolCall(
|
||||
id="2",
|
||||
type="function",
|
||||
function=Function(
|
||||
name="_pass_function",
|
||||
arguments=json.dumps({"input": "task2"}),
|
||||
),
|
||||
),
|
||||
ChatCompletionMessageToolCall(
|
||||
id="3",
|
||||
type="function",
|
||||
function=Function(
|
||||
name="_echo_function",
|
||||
arguments=json.dumps({"input": "task3"}),
|
||||
),
|
||||
),
|
||||
],
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="pass", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
async def test_run_with_parallel_tools() -> None:
|
||||
model_client = ReplayChatCompletionClient(
|
||||
[
|
||||
CreateResult(
|
||||
finish_reason="function_calls",
|
||||
content=[
|
||||
FunctionCall(id="1", arguments=json.dumps({"input": "task1"}), name="_pass_function"),
|
||||
FunctionCall(id="2", arguments=json.dumps({"input": "task2"}), name="_pass_function"),
|
||||
FunctionCall(id="3", arguments=json.dumps({"input": "task3"}), name="_echo_function"),
|
||||
],
|
||||
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
||||
thought="Calling pass and echo functions",
|
||||
cached=False,
|
||||
),
|
||||
"pass",
|
||||
"TERMINATE",
|
||||
],
|
||||
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
|
||||
)
|
||||
agent = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
tools=[
|
||||
_pass_function,
|
||||
_fail_function,
|
||||
@@ -411,7 +250,7 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert result.messages[4].models_usage is None
|
||||
|
||||
# Test streaming.
|
||||
mock.curr_index = 0 # Reset the mock
|
||||
model_client.reset()
|
||||
index = 0
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
@@ -424,7 +263,7 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
state = await agent.save_state()
|
||||
agent2 = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||
)
|
||||
await agent2.load_state(state)
|
||||
@@ -433,86 +272,27 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_parallel_tools_with_empty_call_ids(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id1",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="tool_calls",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id="",
|
||||
type="function",
|
||||
function=Function(
|
||||
name="_pass_function",
|
||||
arguments=json.dumps({"input": "task1"}),
|
||||
),
|
||||
),
|
||||
ChatCompletionMessageToolCall(
|
||||
id="",
|
||||
type="function",
|
||||
function=Function(
|
||||
name="_pass_function",
|
||||
arguments=json.dumps({"input": "task2"}),
|
||||
),
|
||||
),
|
||||
ChatCompletionMessageToolCall(
|
||||
id="",
|
||||
type="function",
|
||||
function=Function(
|
||||
name="_echo_function",
|
||||
arguments=json.dumps({"input": "task3"}),
|
||||
),
|
||||
),
|
||||
],
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="pass", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
|
||||
model_client = ReplayChatCompletionClient(
|
||||
[
|
||||
CreateResult(
|
||||
finish_reason="function_calls",
|
||||
content=[
|
||||
FunctionCall(id="", arguments=json.dumps({"input": "task1"}), name="_pass_function"),
|
||||
FunctionCall(id="", arguments=json.dumps({"input": "task2"}), name="_pass_function"),
|
||||
FunctionCall(id="", arguments=json.dumps({"input": "task3"}), name="_echo_function"),
|
||||
],
|
||||
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
||||
cached=False,
|
||||
),
|
||||
"pass",
|
||||
"TERMINATE",
|
||||
],
|
||||
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
|
||||
)
|
||||
agent = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
tools=[
|
||||
_pass_function,
|
||||
_fail_function,
|
||||
@@ -547,7 +327,7 @@ async def test_run_with_parallel_tools_with_empty_call_ids(monkeypatch: pytest.M
|
||||
assert result.messages[3].models_usage is None
|
||||
|
||||
# Test streaming.
|
||||
mock.curr_index = 0 # Reset the mock
|
||||
model_client.reset()
|
||||
index = 0
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
@@ -560,7 +340,7 @@ async def test_run_with_parallel_tools_with_empty_call_ids(monkeypatch: pytest.M
|
||||
state = await agent.save_state()
|
||||
agent2 = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||
)
|
||||
await agent2.load_state(state)
|
||||
@@ -569,43 +349,24 @@ async def test_run_with_parallel_tools_with_empty_call_ids(monkeypatch: pytest.M
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
async def test_handoffs() -> None:
|
||||
handoff = Handoff(target="agent2")
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id1",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="tool_calls",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
id="1",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=handoff.name,
|
||||
arguments=json.dumps({}),
|
||||
),
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=42, completion_tokens=43, total_tokens=85),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
model_client = ReplayChatCompletionClient(
|
||||
[
|
||||
CreateResult(
|
||||
finish_reason="function_calls",
|
||||
content=[
|
||||
FunctionCall(id="1", arguments=json.dumps({}), name=handoff.name),
|
||||
],
|
||||
usage=RequestUsage(prompt_tokens=42, completion_tokens=43),
|
||||
cached=False,
|
||||
)
|
||||
],
|
||||
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
|
||||
)
|
||||
tool_use_agent = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
tools=[
|
||||
_pass_function,
|
||||
_fail_function,
|
||||
@@ -630,7 +391,7 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert result.messages[3].models_usage is None
|
||||
|
||||
# Test streaming.
|
||||
mock.curr_index = 0 # pyright: ignore
|
||||
model_client.reset()
|
||||
index = 0
|
||||
async for message in tool_use_agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
@@ -642,28 +403,10 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="Hello", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
model_client = ReplayChatCompletionClient(["Hello"])
|
||||
agent = AssistantAgent(
|
||||
name="assistant",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
)
|
||||
# Generate a random base64 image.
|
||||
img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
|
||||
@@ -698,7 +441,7 @@ async def test_invalid_model_capabilities() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_images(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
async def test_remove_images() -> None:
|
||||
model = "random-model"
|
||||
model_client_1 = OpenAIChatCompletionClient(
|
||||
model=model,
|
||||
@@ -732,28 +475,19 @@ async def test_remove_images(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id1",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="Response to message 1", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
model_client = ReplayChatCompletionClient(
|
||||
[
|
||||
CreateResult(
|
||||
finish_reason="stop",
|
||||
content="Response to message 1",
|
||||
usage=RequestUsage(prompt_tokens=10, completion_tokens=5),
|
||||
cached=False,
|
||||
)
|
||||
]
|
||||
)
|
||||
agent = AssistantAgent(
|
||||
"test_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
)
|
||||
|
||||
# Create a list of chat messages
|
||||
@@ -779,7 +513,7 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert result.messages[2].models_usage.prompt_tokens == 10
|
||||
|
||||
# Test run_stream method with list of messages
|
||||
mock.curr_index = 0 # Reset mock index using public attribute
|
||||
model_client.reset() # Reset the mock client
|
||||
index = 0
|
||||
async for message in agent.run_stream(task=messages):
|
||||
if isinstance(message, TaskResult):
|
||||
@@ -791,29 +525,11 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id1",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="Response to message 3", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
model_client = ReplayChatCompletionClient(["Response to message 3"])
|
||||
model_context = BufferedChatCompletionContext(buffer_size=2)
|
||||
agent = AssistantAgent(
|
||||
"test_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
model_context=model_context,
|
||||
)
|
||||
|
||||
@@ -825,33 +541,15 @@ async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
await agent.run(task=messages)
|
||||
|
||||
# Check if the mock client is called with only the last two messages.
|
||||
assert len(mock.calls) == 1
|
||||
assert len(model_client.create_calls) == 1
|
||||
# 2 message from the context + 1 system message
|
||||
assert len(mock.calls[0]) == 3
|
||||
assert len(model_client.create_calls[0]["messages"]) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id1",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="Hello", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
]
|
||||
model_client = ReplayChatCompletionClient(["Hello"])
|
||||
b64_image_str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
|
||||
# Test basic memory properties and empty context
|
||||
memory = ListMemory(name="test_memory")
|
||||
@@ -883,7 +581,7 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
with pytest.raises(TypeError):
|
||||
AssistantAgent(
|
||||
"test_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
memory="invalid", # type: ignore
|
||||
)
|
||||
|
||||
@@ -891,9 +589,7 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
memory2 = ListMemory()
|
||||
await memory2.add(MemoryContent(content="test instruction", mime_type=MemoryMimeType.TEXT))
|
||||
|
||||
agent = AssistantAgent(
|
||||
"test_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), memory=[memory2]
|
||||
)
|
||||
agent = AssistantAgent("test_agent", model_client=model_client, memory=[memory2])
|
||||
|
||||
# Test dump and load component with memory
|
||||
agent_config: ComponentModel = agent.dump_component()
|
||||
@@ -916,30 +612,15 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id1",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="Response to message 3", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
async def test_assistant_agent_declarative() -> None:
|
||||
model_client = ReplayChatCompletionClient(
|
||||
["Response to message 3"],
|
||||
model_info={"function_calling": True, "vision": True, "json_output": True, "family": ModelFamily.GPT_4O},
|
||||
)
|
||||
model_context = BufferedChatCompletionContext(buffer_size=2)
|
||||
agent = AssistantAgent(
|
||||
"test_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
model_context=model_context,
|
||||
memory=[ListMemory(name="test_memory")],
|
||||
)
|
||||
@@ -952,7 +633,7 @@ async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> N
|
||||
|
||||
agent3 = AssistantAgent(
|
||||
"test_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
model_client=model_client,
|
||||
model_context=model_context,
|
||||
tools=[
|
||||
_pass_function,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,10 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Sequence
|
||||
from typing import AsyncGenerator, Sequence
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from autogen_agentchat import EVENT_LOGGER_NAME
|
||||
from autogen_agentchat.agents import (
|
||||
BaseChatAgent,
|
||||
@@ -17,7 +18,7 @@ from autogen_agentchat.teams import (
|
||||
MagenticOneGroupChat,
|
||||
)
|
||||
from autogen_agentchat.teams._group_chat._magentic_one._magentic_one_orchestrator import MagenticOneOrchestrator
|
||||
from autogen_core import AgentId, CancellationToken
|
||||
from autogen_core import AgentId, AgentRuntime, CancellationToken, SingleThreadedAgentRuntime
|
||||
from autogen_ext.models.replay import ReplayChatCompletionClient
|
||||
from utils import FileLogHandler
|
||||
|
||||
@@ -55,8 +56,19 @@ class _EchoAgent(BaseChatAgent):
|
||||
self._last_message = None
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(params=["single_threaded", "embedded"]) # type: ignore
|
||||
async def runtime(request: pytest.FixtureRequest) -> AsyncGenerator[AgentRuntime | None, None]:
|
||||
if request.param == "single_threaded":
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
runtime.start()
|
||||
yield runtime
|
||||
await runtime.stop()
|
||||
elif request.param == "embedded":
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_magentic_one_group_chat_cancellation() -> None:
|
||||
async def test_magentic_one_group_chat_cancellation(runtime: AgentRuntime | None) -> None:
|
||||
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
|
||||
agent_2 = _EchoAgent("agent_2", description="echo agent 2")
|
||||
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
|
||||
@@ -67,7 +79,9 @@ async def test_magentic_one_group_chat_cancellation() -> None:
|
||||
)
|
||||
|
||||
# Set max_turns to a large number to avoid stopping due to max_turns before cancellation.
|
||||
team = MagenticOneGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client)
|
||||
team = MagenticOneGroupChat(
|
||||
participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client, runtime=runtime
|
||||
)
|
||||
cancellation_token = CancellationToken()
|
||||
run_task = asyncio.create_task(
|
||||
team.run(
|
||||
@@ -83,7 +97,7 @@ async def test_magentic_one_group_chat_cancellation() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_magentic_one_group_chat_basic() -> None:
|
||||
async def test_magentic_one_group_chat_basic(runtime: AgentRuntime | None) -> None:
|
||||
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
|
||||
agent_2 = _EchoAgent("agent_2", description="echo agent 2")
|
||||
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
|
||||
@@ -115,7 +129,9 @@ async def test_magentic_one_group_chat_basic() -> None:
|
||||
],
|
||||
)
|
||||
|
||||
team = MagenticOneGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client)
|
||||
team = MagenticOneGroupChat(
|
||||
participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client, runtime=runtime
|
||||
)
|
||||
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
||||
assert len(result.messages) == 5
|
||||
assert result.messages[2].content == "Continue task"
|
||||
@@ -124,16 +140,18 @@ async def test_magentic_one_group_chat_basic() -> None:
|
||||
|
||||
# Test save and load.
|
||||
state = await team.save_state()
|
||||
team2 = MagenticOneGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client)
|
||||
team2 = MagenticOneGroupChat(
|
||||
participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client, runtime=runtime
|
||||
)
|
||||
await team2.load_state(state)
|
||||
state2 = await team2.save_state()
|
||||
assert state == state2
|
||||
manager_1 = await team._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||
AgentId("group_chat_manager", team._team_id), # pyright: ignore
|
||||
AgentId(f"{team._group_chat_manager_name}_{team._team_id}", team._team_id), # pyright: ignore
|
||||
MagenticOneOrchestrator, # pyright: ignore
|
||||
) # pyright: ignore
|
||||
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||
AgentId("group_chat_manager", team2._team_id), # pyright: ignore
|
||||
AgentId(f"{team2._group_chat_manager_name}_{team2._team_id}", team2._team_id), # pyright: ignore
|
||||
MagenticOneOrchestrator, # pyright: ignore
|
||||
) # pyright: ignore
|
||||
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
|
||||
@@ -145,7 +163,7 @@ async def test_magentic_one_group_chat_basic() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_magentic_one_group_chat_with_stalls() -> None:
|
||||
async def test_magentic_one_group_chat_with_stalls(runtime: AgentRuntime | None) -> None:
|
||||
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
|
||||
agent_2 = _EchoAgent("agent_2", description="echo agent 2")
|
||||
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
|
||||
@@ -189,7 +207,10 @@ async def test_magentic_one_group_chat_with_stalls() -> None:
|
||||
)
|
||||
|
||||
team = MagenticOneGroupChat(
|
||||
participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client, max_stalls=2
|
||||
participants=[agent_1, agent_2, agent_3, agent_4],
|
||||
model_client=model_client,
|
||||
max_stalls=2,
|
||||
runtime=runtime,
|
||||
)
|
||||
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
||||
assert len(result.messages) == 6
|
||||
|
||||
@@ -1,75 +1,34 @@
|
||||
import asyncio
|
||||
from typing import Any, AsyncGenerator, List
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent
|
||||
from autogen_agentchat.conditions import MaxMessageTermination
|
||||
from autogen_agentchat.teams import RoundRobinGroupChat
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from openai.resources.chat.completions import AsyncCompletions
|
||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from autogen_core import AgentRuntime, SingleThreadedAgentRuntime
|
||||
from autogen_ext.models.replay import ReplayChatCompletionClient
|
||||
|
||||
|
||||
class _MockChatCompletion:
|
||||
def __init__(self, chat_completions: List[ChatCompletion]) -> None:
|
||||
self._saved_chat_completions = chat_completions
|
||||
self._curr_index = 0
|
||||
|
||||
async def mock_create(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
|
||||
await asyncio.sleep(0.1)
|
||||
completion = self._saved_chat_completions[self._curr_index]
|
||||
self._curr_index += 1
|
||||
return completion
|
||||
@pytest_asyncio.fixture(params=["single_threaded", "embedded"]) # type: ignore
|
||||
async def runtime(request: pytest.FixtureRequest) -> AsyncGenerator[AgentRuntime | None, None]:
|
||||
if request.param == "single_threaded":
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
runtime.start()
|
||||
yield runtime
|
||||
await runtime.stop()
|
||||
elif request.param == "embedded":
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_society_of_mind_agent(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="1", role="assistant"))
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="2", role="assistant"))
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="3", role="assistant"))
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="")
|
||||
|
||||
async def test_society_of_mind_agent(runtime: AgentRuntime | None) -> None:
|
||||
model_client = ReplayChatCompletionClient(
|
||||
["1", "2", "3"],
|
||||
)
|
||||
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
inner_termination = MaxMessageTermination(3)
|
||||
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
|
||||
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination, runtime=runtime)
|
||||
society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
|
||||
response = await society_of_mind_agent.run(task="Count to 10.")
|
||||
assert len(response.messages) == 4
|
||||
@@ -84,14 +43,13 @@ async def test_society_of_mind_agent(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
inner_termination = MaxMessageTermination(3)
|
||||
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
|
||||
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination, runtime=runtime)
|
||||
society_of_mind_agent2 = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
|
||||
await society_of_mind_agent2.load_state(state)
|
||||
state2 = await society_of_mind_agent2.save_state()
|
||||
assert state == state2
|
||||
|
||||
# Test serialization.
|
||||
|
||||
soc_agent_config = society_of_mind_agent.dump_component()
|
||||
assert soc_agent_config.provider == "autogen_agentchat.agents.SocietyOfMindAgent"
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union
|
||||
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Sequence, Union
|
||||
from typing_extensions import Self
|
||||
|
||||
from autogen_core import EVENT_LOGGER_NAME, CancellationToken
|
||||
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component
|
||||
from autogen_core.models import (
|
||||
ChatCompletionClient,
|
||||
CreateResult,
|
||||
@@ -13,13 +14,22 @@ from autogen_core.models import (
|
||||
ModelFamily,
|
||||
ModelInfo,
|
||||
RequestUsage,
|
||||
validate_model_info,
|
||||
)
|
||||
from autogen_core.tools import Tool, ToolSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class ReplayChatCompletionClient(ChatCompletionClient):
|
||||
class ReplayChatCompletionClientConfig(BaseModel):
|
||||
"""ReplayChatCompletionClient configuration."""
|
||||
|
||||
chat_completions: Sequence[Union[str, CreateResult]]
|
||||
model_info: Optional[ModelInfo] = None
|
||||
|
||||
|
||||
class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompletionClientConfig]):
|
||||
"""
|
||||
A mock chat completion client that replays predefined responses using an index-based approach.
|
||||
|
||||
@@ -111,25 +121,37 @@ class ReplayChatCompletionClient(ChatCompletionClient):
|
||||
"""
|
||||
|
||||
__protocol__: ChatCompletionClient
|
||||
component_type = "replay_chat_completion_client"
|
||||
component_provider_override = "autogen_ext.models.replay.ReplayChatCompletionClient"
|
||||
component_config_schema = ReplayChatCompletionClientConfig
|
||||
|
||||
# TODO: Support FunctionCall in responses
|
||||
# TODO: Support logprobs in Responses
|
||||
# TODO: Support model capabilities
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_completions: Sequence[Union[str, CreateResult]],
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
):
|
||||
self.chat_completions = list(chat_completions)
|
||||
self.provided_message_count = len(self.chat_completions)
|
||||
self._model_info = ModelInfo(
|
||||
vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
|
||||
)
|
||||
if model_info is not None:
|
||||
self._model_info = model_info
|
||||
validate_model_info(self._model_info)
|
||||
else:
|
||||
self._model_info = ModelInfo(
|
||||
vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
|
||||
)
|
||||
self._total_available_tokens = 10000
|
||||
self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
|
||||
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
|
||||
self._current_index = 0
|
||||
self._cached_bool_value = True
|
||||
self._create_calls: List[Dict[str, Any]] = []
|
||||
|
||||
@property
|
||||
def create_calls(self) -> List[Dict[str, Any]]:
|
||||
"""Return the arguments of the calls made to the create method."""
|
||||
return self._create_calls
|
||||
|
||||
async def create(
|
||||
self,
|
||||
@@ -159,6 +181,15 @@ class ReplayChatCompletionClient(ChatCompletionClient):
|
||||
|
||||
self._update_total_usage()
|
||||
self._current_index += 1
|
||||
self._create_calls.append(
|
||||
{
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"json_output": json_output,
|
||||
"extra_create_args": extra_create_args,
|
||||
"cancellation_token": cancellation_token,
|
||||
}
|
||||
)
|
||||
return response
|
||||
|
||||
async def create_stream(
|
||||
@@ -259,3 +290,16 @@ class ReplayChatCompletionClient(ChatCompletionClient):
|
||||
self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
|
||||
self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
|
||||
self._current_index = 0
|
||||
|
||||
def _to_config(self) -> ReplayChatCompletionClientConfig:
|
||||
return ReplayChatCompletionClientConfig(
|
||||
chat_completions=self.chat_completions,
|
||||
model_info=self._model_info,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: ReplayChatCompletionClientConfig) -> Self:
|
||||
return cls(
|
||||
chat_completions=config.chat_completions,
|
||||
model_info=config.model_info,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user