mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
feat: add support for list of messages as team task input and update Society of Mind Agent (#4500)
* feat: add support for list of messages as team task input * Update society of mind agent to use the list input task --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Ryan Sweet <rysweet@microsoft.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Sequence
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Sequence, get_args
|
||||
|
||||
from autogen_core import CancellationToken
|
||||
|
||||
from ..base import ChatAgent, Response, TaskResult
|
||||
from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
|
||||
from ..messages import (
|
||||
AgentMessage,
|
||||
ChatMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from ..state import BaseState
|
||||
|
||||
|
||||
@@ -45,8 +49,9 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[AgentMessage | Response, None]:
|
||||
"""Handles incoming messages and returns a stream of messages and
|
||||
and the final item is the response. The base implementation in :class:`BaseChatAgent`
|
||||
simply calls :meth:`on_messages` and yields the messages in the response."""
|
||||
and the final item is the response. The base implementation in
|
||||
:class:`BaseChatAgent` simply calls :meth:`on_messages` and yields
|
||||
the messages in the response."""
|
||||
response = await self.on_messages(messages, cancellation_token)
|
||||
for inner_message in response.inner_messages or []:
|
||||
yield inner_message
|
||||
@@ -55,7 +60,7 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | None = None,
|
||||
task: str | ChatMessage | List[ChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> TaskResult:
|
||||
"""Run the agent with the given task and return the result."""
|
||||
@@ -69,7 +74,14 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
text_msg = TextMessage(content=task, source="user")
|
||||
input_messages.append(text_msg)
|
||||
output_messages.append(text_msg)
|
||||
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
|
||||
elif isinstance(task, list):
|
||||
for msg in task:
|
||||
if isinstance(msg, get_args(ChatMessage)[0]):
|
||||
input_messages.append(msg)
|
||||
output_messages.append(msg)
|
||||
else:
|
||||
raise ValueError(f"Invalid message type in list: {type(msg)}")
|
||||
elif isinstance(task, get_args(ChatMessage)[0]):
|
||||
input_messages.append(task)
|
||||
output_messages.append(task)
|
||||
else:
|
||||
@@ -83,7 +95,7 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
async def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | None = None,
|
||||
task: str | ChatMessage | List[ChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
|
||||
"""Run the agent with the given task and return a stream of messages
|
||||
@@ -99,7 +111,15 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
input_messages.append(text_msg)
|
||||
output_messages.append(text_msg)
|
||||
yield text_msg
|
||||
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
|
||||
elif isinstance(task, list):
|
||||
for msg in task:
|
||||
if isinstance(msg, get_args(ChatMessage)[0]):
|
||||
input_messages.append(msg)
|
||||
output_messages.append(msg)
|
||||
yield msg
|
||||
else:
|
||||
raise ValueError(f"Invalid message type in list: {type(msg)}")
|
||||
elif isinstance(task, get_args(ChatMessage)[0]):
|
||||
input_messages.append(task)
|
||||
output_messages.append(task)
|
||||
yield task
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import AsyncGenerator, List, Sequence
|
||||
from typing import Any, AsyncGenerator, List, Mapping, Sequence
|
||||
|
||||
from autogen_core import CancellationToken, Image
|
||||
from autogen_core.models import ChatCompletionClient
|
||||
from autogen_core.models._types import SystemMessage
|
||||
from autogen_core import CancellationToken
|
||||
from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
|
||||
|
||||
from autogen_agentchat.base import Response
|
||||
from autogen_agentchat.state import SocietyOfMindAgentState
|
||||
|
||||
from ..base import TaskResult, Team
|
||||
from ..messages import (
|
||||
@@ -32,6 +32,10 @@ class SocietyOfMindAgent(BaseChatAgent):
|
||||
team (Team): The team of agents to use.
|
||||
model_client (ChatCompletionClient): The model client to use for preparing responses.
|
||||
description (str, optional): The description of the agent.
|
||||
instruction (str, optional): The instruction to use when generating a response using the inner team's messages.
|
||||
Defaults to :attr:`DEFAULT_INSTRUCTION`. It assumes the role of 'system'.
|
||||
response_prompt (str, optional): The response prompt to use when generating a response using the inner team's messages.
|
||||
Defaults to :attr:`DEFAULT_RESPONSE_PROMPT`. It assumes the role of 'system'.
|
||||
|
||||
|
||||
Example:
|
||||
@@ -39,35 +43,51 @@ class SocietyOfMindAgent(BaseChatAgent):
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from autogen_agentchat.ui import Console
|
||||
from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_agentchat.teams import RoundRobinGroupChat
|
||||
from autogen_agentchat.conditions import MaxMessageTermination
|
||||
from autogen_agentchat.conditions import TextMentionTermination
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
model_client = OpenAIChatCompletionClient(model="gpt-4o")
|
||||
|
||||
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
inner_termination = MaxMessageTermination(3)
|
||||
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a writer, write well.")
|
||||
agent2 = AssistantAgent(
|
||||
"assistant2",
|
||||
model_client=model_client,
|
||||
system_message="You are an editor, provide critical feedback. Respond with 'APPROVE' if the text addresses all feedbacks.",
|
||||
)
|
||||
inner_termination = TextMentionTermination("APPROVE")
|
||||
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
|
||||
|
||||
society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
|
||||
|
||||
agent3 = AssistantAgent("assistant3", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
agent4 = AssistantAgent("assistant4", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
outter_termination = MaxMessageTermination(10)
|
||||
team = RoundRobinGroupChat([society_of_mind_agent, agent3, agent4], termination_condition=outter_termination)
|
||||
agent3 = AssistantAgent(
|
||||
"assistant3", model_client=model_client, system_message="Translate the text to Spanish."
|
||||
)
|
||||
team = RoundRobinGroupChat([society_of_mind_agent, agent3], max_turns=2)
|
||||
|
||||
stream = team.run_stream(task="Tell me a one-liner joke.")
|
||||
async for message in stream:
|
||||
print(message)
|
||||
stream = team.run_stream(task="Write a short story with a surprising ending.")
|
||||
await Console(stream)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
DEFAULT_INSTRUCTION = "Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:"
|
||||
"""str: The default instruction to use when generating a response using the
|
||||
inner team's messages. The instruction will be prepended to the inner team's
|
||||
messages when generating a response using the model. It assumes the role of
|
||||
'system'."""
|
||||
|
||||
DEFAULT_RESPONSE_PROMPT = (
|
||||
"Output a standalone response to the original request, without mentioning any of the intermediate discussion."
|
||||
)
|
||||
"""str: The default response prompt to use when generating a response using
|
||||
the inner team's messages. It assumes the role of 'system'."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
@@ -75,17 +95,13 @@ class SocietyOfMindAgent(BaseChatAgent):
|
||||
model_client: ChatCompletionClient,
|
||||
*,
|
||||
description: str = "An agent that uses an inner team of agents to generate responses.",
|
||||
task_prompt: str = "{transcript}\nContinue.",
|
||||
response_prompt: str = "Here is a transcript of conversation so far:\n{transcript}\n\\Provide a response to the original request.",
|
||||
instruction: str = DEFAULT_INSTRUCTION,
|
||||
response_prompt: str = DEFAULT_RESPONSE_PROMPT,
|
||||
) -> None:
|
||||
super().__init__(name=name, description=description)
|
||||
self._team = team
|
||||
self._model_client = model_client
|
||||
if "{transcript}" not in task_prompt:
|
||||
raise ValueError("The task prompt must contain the '{transcript}' placeholder for the transcript.")
|
||||
self._task_prompt = task_prompt
|
||||
if "{transcript}" not in response_prompt:
|
||||
raise ValueError("The response prompt must contain the '{transcript}' placeholder for the transcript.")
|
||||
self._instruction = instruction
|
||||
self._response_prompt = response_prompt
|
||||
|
||||
@property
|
||||
@@ -104,33 +120,41 @@ class SocietyOfMindAgent(BaseChatAgent):
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[AgentMessage | Response, None]:
|
||||
# Build the context.
|
||||
delta = list(messages)
|
||||
task: str | None = None
|
||||
if len(delta) > 0:
|
||||
task = self._task_prompt.format(transcript=self._create_transcript(delta))
|
||||
# Prepare the task for the team of agents.
|
||||
task = list(messages)
|
||||
|
||||
# Run the team of agents.
|
||||
result: TaskResult | None = None
|
||||
inner_messages: List[AgentMessage] = []
|
||||
count = 0
|
||||
async for inner_msg in self._team.run_stream(task=task, cancellation_token=cancellation_token):
|
||||
if isinstance(inner_msg, TaskResult):
|
||||
result = inner_msg
|
||||
else:
|
||||
count += 1
|
||||
if count <= len(task):
|
||||
# Skip the task messages.
|
||||
continue
|
||||
yield inner_msg
|
||||
inner_messages.append(inner_msg)
|
||||
assert result is not None
|
||||
|
||||
if len(inner_messages) < 2:
|
||||
# The first message is the task message so we need at least 2 messages.
|
||||
if len(inner_messages) == 0:
|
||||
yield Response(
|
||||
chat_message=TextMessage(source=self.name, content="No response."), inner_messages=inner_messages
|
||||
)
|
||||
else:
|
||||
prompt = self._response_prompt.format(transcript=self._create_transcript(inner_messages[1:]))
|
||||
completion = await self._model_client.create(
|
||||
messages=[SystemMessage(content=prompt)], cancellation_token=cancellation_token
|
||||
# Generate a response using the model client.
|
||||
llm_messages: List[LLMMessage] = [SystemMessage(content=self._instruction)]
|
||||
llm_messages.extend(
|
||||
[
|
||||
UserMessage(content=message.content, source=message.source)
|
||||
for message in inner_messages
|
||||
if isinstance(message, TextMessage | MultiModalMessage | StopMessage | HandoffMessage)
|
||||
]
|
||||
)
|
||||
llm_messages.append(SystemMessage(content=self._response_prompt))
|
||||
completion = await self._model_client.create(messages=llm_messages, cancellation_token=cancellation_token)
|
||||
assert isinstance(completion.content, str)
|
||||
yield Response(
|
||||
chat_message=TextMessage(source=self.name, content=completion.content, models_usage=completion.usage),
|
||||
@@ -143,17 +167,11 @@ class SocietyOfMindAgent(BaseChatAgent):
|
||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||
await self._team.reset()
|
||||
|
||||
def _create_transcript(self, messages: Sequence[AgentMessage]) -> str:
|
||||
transcript = ""
|
||||
for message in messages:
|
||||
if isinstance(message, TextMessage | StopMessage | HandoffMessage):
|
||||
transcript += f"{message.source}: {message.content}\n"
|
||||
elif isinstance(message, MultiModalMessage):
|
||||
for content in message.content:
|
||||
if isinstance(content, Image):
|
||||
transcript += f"{message.source}: [Image]\n"
|
||||
else:
|
||||
transcript += f"{message.source}: {content}\n"
|
||||
else:
|
||||
raise ValueError(f"Unexpected message type: {message} in {self.__class__.__name__}")
|
||||
return transcript
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
team_state = await self._team.save_state()
|
||||
state = SocietyOfMindAgentState(inner_team_state=team_state)
|
||||
return state.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
society_of_mind_state = SocietyOfMindAgentState.model_validate(state)
|
||||
await self._team.load_state(society_of_mind_state.inner_team_state)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator, Protocol, Sequence
|
||||
from typing import AsyncGenerator, List, Protocol, Sequence
|
||||
|
||||
from autogen_core import CancellationToken
|
||||
|
||||
@@ -23,7 +23,7 @@ class TaskRunner(Protocol):
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | None = None,
|
||||
task: str | ChatMessage | List[ChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> TaskResult:
|
||||
"""Run the task and return the result.
|
||||
@@ -36,7 +36,7 @@ class TaskRunner(Protocol):
|
||||
def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | None = None,
|
||||
task: str | ChatMessage | List[ChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
|
||||
"""Run the task and produces a stream of messages and the final result
|
||||
|
||||
@@ -8,6 +8,7 @@ from ._states import (
|
||||
MagenticOneOrchestratorState,
|
||||
RoundRobinManagerState,
|
||||
SelectorManagerState,
|
||||
SocietyOfMindAgentState,
|
||||
SwarmManagerState,
|
||||
TeamState,
|
||||
)
|
||||
@@ -22,4 +23,5 @@ __all__ = [
|
||||
"SwarmManagerState",
|
||||
"MagenticOneOrchestratorState",
|
||||
"TeamState",
|
||||
"SocietyOfMindAgentState",
|
||||
]
|
||||
|
||||
@@ -79,3 +79,10 @@ class MagenticOneOrchestratorState(BaseGroupChatManagerState):
|
||||
n_rounds: int = Field(default=0)
|
||||
n_stalls: int = Field(default=0)
|
||||
type: str = Field(default="MagenticOneOrchestratorState")
|
||||
|
||||
|
||||
class SocietyOfMindAgentState(BaseState):
|
||||
"""State for a Society of Mind agent."""
|
||||
|
||||
inner_team_state: Mapping[str, Any] = Field(default_factory=dict)
|
||||
type: str = Field(default="SocietyOfMindAgentState")
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Callable, List, Mapping
|
||||
from typing import Any, AsyncGenerator, Callable, List, Mapping, get_args
|
||||
|
||||
from autogen_core import (
|
||||
AgentId,
|
||||
@@ -19,7 +19,7 @@ from autogen_core._closure_agent import ClosureContext
|
||||
|
||||
from ... import EVENT_LOGGER_NAME
|
||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
||||
from ...messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
|
||||
from ...messages import AgentMessage, ChatMessage, TextMessage
|
||||
from ...state import TeamState
|
||||
from ._chat_agent_container import ChatAgentContainer
|
||||
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
|
||||
@@ -146,11 +146,18 @@ class BaseGroupChat(Team, ABC):
|
||||
message: GroupChatStart | GroupChatMessage | GroupChatTermination,
|
||||
ctx: MessageContext,
|
||||
) -> None:
|
||||
event_logger.info(message.message)
|
||||
if isinstance(message, GroupChatTermination):
|
||||
"""Collect output messages from the group chat."""
|
||||
if isinstance(message, GroupChatStart):
|
||||
if message.messages is not None:
|
||||
for msg in message.messages:
|
||||
event_logger.info(msg)
|
||||
await self._output_message_queue.put(msg)
|
||||
elif isinstance(message, GroupChatMessage):
|
||||
event_logger.info(message.message)
|
||||
await self._output_message_queue.put(message.message)
|
||||
elif isinstance(message, GroupChatTermination):
|
||||
event_logger.info(message.message)
|
||||
self._stop_reason = message.message.content
|
||||
return
|
||||
await self._output_message_queue.put(message.message)
|
||||
|
||||
await ClosureAgent.register_closure(
|
||||
runtime,
|
||||
@@ -165,7 +172,7 @@ class BaseGroupChat(Team, ABC):
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | None = None,
|
||||
task: str | ChatMessage | List[ChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> TaskResult:
|
||||
"""Run the team and return the result. The base implementation uses
|
||||
@@ -173,7 +180,7 @@ class BaseGroupChat(Team, ABC):
|
||||
Once the team is stopped, the termination condition is reset.
|
||||
|
||||
Args:
|
||||
task (str | ChatMessage | None): The task to run the team with.
|
||||
task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
|
||||
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
|
||||
Setting the cancellation token potentially put the team in an inconsistent state,
|
||||
and it may not reset the termination condition.
|
||||
@@ -264,7 +271,7 @@ class BaseGroupChat(Team, ABC):
|
||||
async def run_stream(
|
||||
self,
|
||||
*,
|
||||
task: str | ChatMessage | None = None,
|
||||
task: str | ChatMessage | List[ChatMessage] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
|
||||
"""Run the team and produces a stream of messages and the final result
|
||||
@@ -272,7 +279,7 @@ class BaseGroupChat(Team, ABC):
|
||||
team is stopped, the termination condition is reset.
|
||||
|
||||
Args:
|
||||
task (str | ChatMessage | None): The task to run the team with.
|
||||
task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
|
||||
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
|
||||
Setting the cancellation token potentially put the team in an inconsistent state,
|
||||
and it may not reset the termination condition.
|
||||
@@ -355,16 +362,20 @@ class BaseGroupChat(Team, ABC):
|
||||
|
||||
"""
|
||||
|
||||
# Create the first chat message if the task is a string or a chat message.
|
||||
first_chat_message: ChatMessage | None = None
|
||||
# Create the messages list if the task is a string or a chat message.
|
||||
messages: List[ChatMessage] | None = None
|
||||
if task is None:
|
||||
pass
|
||||
elif isinstance(task, str):
|
||||
first_chat_message = TextMessage(content=task, source="user")
|
||||
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
|
||||
first_chat_message = task
|
||||
else:
|
||||
raise ValueError(f"Invalid task type: {type(task)}")
|
||||
messages = [TextMessage(content=task, source="user")]
|
||||
elif isinstance(task, get_args(ChatMessage)[0]):
|
||||
messages = [task] # type: ignore
|
||||
elif isinstance(task, list):
|
||||
if not task:
|
||||
raise ValueError("Task list cannot be empty")
|
||||
if not all(isinstance(msg, get_args(ChatMessage)[0]) for msg in task):
|
||||
raise ValueError("All messages in task list must be valid ChatMessage types")
|
||||
messages = task
|
||||
|
||||
if self._is_running:
|
||||
raise ValueError("The team is already running, it cannot run again until it is stopped.")
|
||||
@@ -389,7 +400,7 @@ class BaseGroupChat(Team, ABC):
|
||||
# The group chat manager will start the group chat by relaying the message to the participants
|
||||
# and the closure agent.
|
||||
await self._runtime.send_message(
|
||||
GroupChatStart(message=first_chat_message),
|
||||
GroupChatStart(messages=messages),
|
||||
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
|
||||
@@ -70,24 +70,28 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Validate the group state given the start message.
|
||||
await self.validate_group_state(message.message)
|
||||
# Validate the group state given the start messages
|
||||
await self.validate_group_state(message.messages)
|
||||
|
||||
if message.message is not None:
|
||||
# Log the start message.
|
||||
await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))
|
||||
|
||||
# Relay the start message to the participants.
|
||||
if message.messages is not None:
|
||||
# Log all messages at once
|
||||
await self.publish_message(
|
||||
message, topic_id=DefaultTopicId(type=self._group_topic_type), cancellation_token=ctx.cancellation_token
|
||||
GroupChatStart(messages=message.messages), topic_id=DefaultTopicId(type=self._output_topic_type)
|
||||
)
|
||||
|
||||
# Append the user message to the message thread.
|
||||
self._message_thread.append(message.message)
|
||||
# Relay all messages at once to participants
|
||||
await self.publish_message(
|
||||
GroupChatStart(messages=message.messages),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
|
||||
# Check if the conversation should be terminated.
|
||||
# Append all messages to thread
|
||||
self._message_thread.extend(message.messages)
|
||||
|
||||
# Check termination condition after processing all messages
|
||||
if self._termination_condition is not None:
|
||||
stop_message = await self._termination_condition([message.message])
|
||||
stop_message = await self._termination_condition(message.messages)
|
||||
if stop_message is not None:
|
||||
await self.publish_message(
|
||||
GroupChatTermination(message=stop_message),
|
||||
@@ -97,7 +101,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
await self._termination_condition.reset()
|
||||
return
|
||||
|
||||
# Select a speaker to start the conversation.
|
||||
# Select a speaker to start/continue the conversation
|
||||
speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
|
||||
# Link the select speaker future to the cancellation token.
|
||||
ctx.cancellation_token.link_future(speaker_topic_type_future)
|
||||
@@ -166,8 +170,13 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
await self.reset()
|
||||
|
||||
@abstractmethod
|
||||
async def validate_group_state(self, message: ChatMessage | None) -> None:
|
||||
"""Validate the state of the group chat given the start message. This is executed when the group chat manager receives a GroupChatStart event."""
|
||||
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
|
||||
"""Validate the state of the group chat given the start messages.
|
||||
This is executed when the group chat manager receives a GroupChatStart event.
|
||||
|
||||
Args:
|
||||
messages: A list of chat messages to validate, or None if no messages are provided.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -30,8 +30,8 @@ class ChatAgentContainer(SequentialRoutedAgent):
|
||||
@event
|
||||
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
|
||||
"""Handle a start event by appending the content to the buffer."""
|
||||
if message.message is not None:
|
||||
self._message_buffer.append(message.message)
|
||||
if message.messages is not None:
|
||||
self._message_buffer.extend(message.messages)
|
||||
|
||||
@event
|
||||
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...base import Response
|
||||
@@ -7,8 +9,8 @@ from ...messages import AgentMessage, ChatMessage, StopMessage
|
||||
class GroupChatStart(BaseModel):
|
||||
"""A request to start a group chat."""
|
||||
|
||||
message: ChatMessage | None = None
|
||||
"""An optional user message to start the group chat."""
|
||||
messages: List[ChatMessage] | None = None
|
||||
"""An optional list of messages to start the group chat."""
|
||||
|
||||
|
||||
class GroupChatAgentResponse(BaseModel):
|
||||
|
||||
@@ -126,17 +126,18 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
)
|
||||
# Stop the group chat.
|
||||
return
|
||||
assert message is not None and message.message is not None
|
||||
assert message is not None and message.messages is not None
|
||||
|
||||
# Validate the group state given the start message.
|
||||
await self.validate_group_state(message.message)
|
||||
# Validate the group state given all the messages.
|
||||
await self.validate_group_state(message.messages)
|
||||
|
||||
# Log the start message.
|
||||
# Log the message.
|
||||
await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))
|
||||
# Outer Loop for first time
|
||||
# Create the initial task ledger
|
||||
#################################
|
||||
self._task = self._content_to_str(message.message.content)
|
||||
# Combine all message contents for task
|
||||
self._task = " ".join([self._content_to_str(msg.content) for msg in message.messages])
|
||||
planning_conversation: List[LLMMessage] = []
|
||||
|
||||
# 1. GATHER FACTS
|
||||
@@ -184,7 +185,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
return
|
||||
await self._orchestrate_step(ctx.cancellation_token)
|
||||
|
||||
async def validate_group_state(self, message: ChatMessage | None) -> None:
|
||||
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
|
||||
pass
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
|
||||
@@ -29,7 +29,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
|
||||
)
|
||||
self._next_speaker_index = 0
|
||||
|
||||
async def validate_group_state(self, message: ChatMessage | None) -> None:
|
||||
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
|
||||
pass
|
||||
|
||||
async def reset(self) -> None:
|
||||
|
||||
@@ -54,7 +54,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
self._allow_repeated_speaker = allow_repeated_speaker
|
||||
self._selector_func = selector_func
|
||||
|
||||
async def validate_group_state(self, message: ChatMessage | None) -> None:
|
||||
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
|
||||
pass
|
||||
|
||||
async def reset(self) -> None:
|
||||
|
||||
@@ -29,16 +29,19 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
||||
)
|
||||
self._current_speaker = participant_topic_types[0]
|
||||
|
||||
async def validate_group_state(self, message: ChatMessage | None) -> None:
|
||||
"""Validate the start message for the group chat."""
|
||||
# Check if the start message is a handoff message.
|
||||
if isinstance(message, HandoffMessage):
|
||||
if message.target not in self._participant_topic_types:
|
||||
raise ValueError(
|
||||
f"The target {message.target} is not one of the participants {self._participant_topic_types}. "
|
||||
"If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target."
|
||||
)
|
||||
return
|
||||
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
|
||||
"""Validate the start messages for the group chat."""
|
||||
# Check if any of the start messages is a handoff message.
|
||||
if messages:
|
||||
for message in messages:
|
||||
if isinstance(message, HandoffMessage):
|
||||
if message.target not in self._participant_topic_types:
|
||||
raise ValueError(
|
||||
f"The target {message.target} is not one of the participants {self._participant_topic_types}. "
|
||||
"If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target."
|
||||
)
|
||||
return
|
||||
|
||||
# Check if there is a handoff message in the thread that is not targeting a valid participant.
|
||||
for existing_message in reversed(self._message_thread):
|
||||
if isinstance(existing_message, HandoffMessage):
|
||||
|
||||
@@ -8,6 +8,7 @@ from autogen_agentchat import EVENT_LOGGER_NAME
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.base import Handoff, TaskResult
|
||||
from autogen_agentchat.messages import (
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
TextMessage,
|
||||
@@ -21,7 +22,10 @@ from openai.resources.chat.completions import AsyncCompletions
|
||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from utils import FileLogHandler
|
||||
|
||||
@@ -33,14 +37,14 @@ logger.addHandler(FileLogHandler("test_assistant_agent.log"))
|
||||
class _MockChatCompletion:
|
||||
def __init__(self, chat_completions: List[ChatCompletion]) -> None:
|
||||
self._saved_chat_completions = chat_completions
|
||||
self._curr_index = 0
|
||||
self.curr_index = 0
|
||||
|
||||
async def mock_create(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
|
||||
await asyncio.sleep(0.1)
|
||||
completion = self._saved_chat_completions[self._curr_index]
|
||||
self._curr_index += 1
|
||||
completion = self._saved_chat_completions[self.curr_index]
|
||||
self.curr_index += 1
|
||||
return completion
|
||||
|
||||
|
||||
@@ -90,7 +94,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="pass", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
@@ -101,7 +109,9 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
@@ -115,7 +125,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
agent = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||
tools=[
|
||||
_pass_function,
|
||||
_fail_function,
|
||||
FunctionTool(_echo_function, description="Echo"),
|
||||
],
|
||||
)
|
||||
result = await agent.run(task="task")
|
||||
|
||||
@@ -133,14 +147,14 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert result.messages[3].models_usage is None
|
||||
|
||||
# Test streaming.
|
||||
mock._curr_index = 0 # pyright: ignore
|
||||
mock.curr_index = 0 # Reset the mock
|
||||
index = 0
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
index += 1
|
||||
index += 1
|
||||
|
||||
# Test state saving and loading.
|
||||
state = await agent.save_state()
|
||||
@@ -234,7 +248,7 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
|
||||
assert result.messages[3].models_usage.prompt_tokens == 10
|
||||
|
||||
# Test streaming.
|
||||
mock._curr_index = 0 # pyright: ignore
|
||||
mock.curr_index = 0 # pyright: ignore
|
||||
index = 0
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
@@ -248,7 +262,11 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
|
||||
agent2 = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||
tools=[
|
||||
_pass_function,
|
||||
_fail_function,
|
||||
FunctionTool(_echo_function, description="Echo"),
|
||||
],
|
||||
)
|
||||
await agent2.load_state(state)
|
||||
state2 = await agent2.save_state()
|
||||
@@ -293,7 +311,11 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
tool_use_agent = AssistantAgent(
|
||||
"tool_use_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||
tools=[
|
||||
_pass_function,
|
||||
_fail_function,
|
||||
FunctionTool(_echo_function, description="Echo"),
|
||||
],
|
||||
handoffs=[handoff],
|
||||
)
|
||||
assert HandoffMessage in tool_use_agent.produced_message_types
|
||||
@@ -313,7 +335,7 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert result.messages[3].models_usage is None
|
||||
|
||||
# Test streaming.
|
||||
mock._curr_index = 0 # pyright: ignore
|
||||
mock.curr_index = 0 # pyright: ignore
|
||||
index = 0
|
||||
async for message in tool_use_agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
@@ -330,7 +352,11 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
ChatCompletion(
|
||||
id="id2",
|
||||
choices=[
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="Hello", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
@@ -340,7 +366,10 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
agent = AssistantAgent(name="assistant", model_client=OpenAIChatCompletionClient(model=model, api_key=""))
|
||||
agent = AssistantAgent(
|
||||
name="assistant",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
)
|
||||
# Generate a random base64 image.
|
||||
img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
|
||||
result = await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
|
||||
@@ -351,14 +380,24 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
async def test_invalid_model_capabilities() -> None:
|
||||
model = "random-model"
|
||||
model_client = OpenAIChatCompletionClient(
|
||||
model=model, api_key="", model_capabilities={"vision": False, "function_calling": False, "json_output": False}
|
||||
model=model,
|
||||
api_key="",
|
||||
model_capabilities={
|
||||
"vision": False,
|
||||
"function_calling": False,
|
||||
"json_output": False,
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
agent = AssistantAgent(
|
||||
name="assistant",
|
||||
model_client=model_client,
|
||||
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||
tools=[
|
||||
_pass_function,
|
||||
_fail_function,
|
||||
FunctionTool(_echo_function, description="Echo"),
|
||||
],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
@@ -369,3 +408,62 @@ async def test_invalid_model_capabilities() -> None:
|
||||
# Generate a random base64 image.
|
||||
img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
|
||||
await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model = "gpt-4o-2024-05-13"
|
||||
chat_completions = [
|
||||
ChatCompletion(
|
||||
id="id1",
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(content="Response to message 1", role="assistant"),
|
||||
)
|
||||
],
|
||||
created=0,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
),
|
||||
]
|
||||
mock = _MockChatCompletion(chat_completions)
|
||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||
agent = AssistantAgent(
|
||||
"test_agent",
|
||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||
)
|
||||
|
||||
# Create a list of chat messages
|
||||
messages: List[ChatMessage] = [
|
||||
TextMessage(content="Message 1", source="user"),
|
||||
TextMessage(content="Message 2", source="user"),
|
||||
]
|
||||
|
||||
# Test run method with list of messages
|
||||
result = await agent.run(task=messages)
|
||||
assert len(result.messages) == 3 # 2 input messages + 1 response message
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert result.messages[0].content == "Message 1"
|
||||
assert result.messages[0].source == "user"
|
||||
assert isinstance(result.messages[1], TextMessage)
|
||||
assert result.messages[1].content == "Message 2"
|
||||
assert result.messages[1].source == "user"
|
||||
assert isinstance(result.messages[2], TextMessage)
|
||||
assert result.messages[2].content == "Response to message 1"
|
||||
assert result.messages[2].source == "test_agent"
|
||||
assert result.messages[2].models_usage is not None
|
||||
assert result.messages[2].models_usage.completion_tokens == 5
|
||||
assert result.messages[2].models_usage.prompt_tokens == 10
|
||||
|
||||
# Test run_stream method with list of messages
|
||||
mock.curr_index = 0 # Reset mock index using public attribute
|
||||
index = 0
|
||||
async for message in agent.run_stream(task=messages):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
index += 1
|
||||
|
||||
@@ -1025,3 +1025,48 @@ async def test_swarm_with_handoff_termination() -> None:
|
||||
assert result.messages[1].content == "Transferred to second_agent."
|
||||
assert result.messages[2].content == "Transferred to third_agent."
|
||||
assert result.messages[3].content == "Transferred to non_existing_agent."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_robin_group_chat_with_message_list() -> None:
|
||||
# Create a simple team with echo agents
|
||||
agent1 = _EchoAgent("Agent1", "First agent")
|
||||
agent2 = _EchoAgent("Agent2", "Second agent")
|
||||
termination = MaxMessageTermination(4) # Stop after 4 messages
|
||||
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
|
||||
|
||||
# Create a list of messages
|
||||
messages: List[ChatMessage] = [
|
||||
TextMessage(content="Message 1", source="user"),
|
||||
TextMessage(content="Message 2", source="user"),
|
||||
TextMessage(content="Message 3", source="user"),
|
||||
]
|
||||
|
||||
# Run the team with the message list
|
||||
result = await team.run(task=messages)
|
||||
|
||||
# Verify the messages were processed in order
|
||||
assert len(result.messages) == 4 # Initial messages + echo until termination
|
||||
assert result.messages[0].content == "Message 1" # First message
|
||||
assert result.messages[1].content == "Message 2" # Second message
|
||||
assert result.messages[2].content == "Message 3" # Third message
|
||||
assert result.messages[3].content == "Message 1" # Echo from first agent
|
||||
assert result.stop_reason == "Maximum number of messages 4 reached, current message count: 4"
|
||||
|
||||
# Test with streaming
|
||||
await team.reset()
|
||||
index = 0
|
||||
async for message in team.run_stream(task=messages):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
index += 1
|
||||
|
||||
# Test with invalid message list
|
||||
with pytest.raises(ValueError, match="All messages in task list must be valid ChatMessage types"):
|
||||
await team.run(task=["not a message"]) # type: ignore[list-item, arg-type] # intentionally testing invalid input
|
||||
|
||||
# Test with empty message list
|
||||
with pytest.raises(ValueError, match="Task list cannot be empty"):
|
||||
await team.run(task=[])
|
||||
|
||||
@@ -72,9 +72,20 @@ async def test_society_of_mind_agent(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
|
||||
society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
|
||||
response = await society_of_mind_agent.run(task="Count to 10.")
|
||||
assert len(response.messages) == 5
|
||||
assert len(response.messages) == 4
|
||||
assert response.messages[0].source == "user"
|
||||
assert response.messages[1].source == "user"
|
||||
assert response.messages[2].source == "assistant1"
|
||||
assert response.messages[3].source == "assistant2"
|
||||
assert response.messages[4].source == "society_of_mind"
|
||||
assert response.messages[1].source == "assistant1"
|
||||
assert response.messages[2].source == "assistant2"
|
||||
assert response.messages[3].source == "society_of_mind"
|
||||
|
||||
# Test save and load state.
|
||||
state = await society_of_mind_agent.save_state()
|
||||
assert state is not None
|
||||
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
|
||||
inner_termination = MaxMessageTermination(3)
|
||||
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
|
||||
society_of_mind_agent2 = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
|
||||
await society_of_mind_agent2.load_state(state)
|
||||
state2 = await society_of_mind_agent2.save_state()
|
||||
assert state == state2
|
||||
|
||||
Reference in New Issue
Block a user