mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
SocietyOfMind agent for nested teams (#4110)
* Initial implementation of SOM agent * add tests * edit prompt * Update prompt * lint
This commit is contained in:
@@ -2,6 +2,7 @@ from ._assistant_agent import AssistantAgent, Handoff
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
from ._code_executor_agent import CodeExecutorAgent
|
||||
from ._coding_assistant_agent import CodingAssistantAgent
|
||||
from ._society_of_mind_agent import SocietyOfMindAgent
|
||||
from ._tool_use_assistant_agent import ToolUseAssistantAgent
|
||||
|
||||
__all__ = [
|
||||
@@ -11,4 +12,5 @@ __all__ = [
|
||||
"CodeExecutorAgent",
|
||||
"CodingAssistantAgent",
|
||||
"ToolUseAssistantAgent",
|
||||
"SocietyOfMindAgent",
|
||||
]
|
||||
|
||||
@@ -20,9 +20,9 @@ from pydantic import BaseModel, Field, model_validator
|
||||
from .. import EVENT_LOGGER_NAME
|
||||
from ..base import Response
|
||||
from ..messages import (
|
||||
AgentMessage,
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
InnerMessage,
|
||||
TextMessage,
|
||||
ToolCallMessage,
|
||||
ToolCallResultMessage,
|
||||
@@ -217,13 +217,13 @@ class AssistantAgent(BaseChatAgent):
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[InnerMessage | Response, None]:
|
||||
) -> AsyncGenerator[AgentMessage | Response, None]:
|
||||
# Add messages to the model context.
|
||||
for msg in messages:
|
||||
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
|
||||
|
||||
# Inner messages.
|
||||
inner_messages: List[InnerMessage] = []
|
||||
inner_messages: List[AgentMessage] = []
|
||||
|
||||
# Generate an inference result based on the current model context.
|
||||
llm_messages = self._system_messages + self._model_context
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import AsyncGenerator, List, Sequence
|
||||
from autogen_core.base import CancellationToken
|
||||
|
||||
from ..base import ChatAgent, Response, TaskResult
|
||||
from ..messages import AgentMessage, ChatMessage, InnerMessage, MultiModalMessage, TextMessage
|
||||
from ..messages import AgentMessage, ChatMessage, MultiModalMessage, TextMessage
|
||||
|
||||
|
||||
class BaseChatAgent(ChatAgent, ABC):
|
||||
@@ -42,7 +42,7 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
|
||||
async def on_messages_stream(
|
||||
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[InnerMessage | Response, None]:
|
||||
) -> AsyncGenerator[AgentMessage | Response, None]:
|
||||
"""Handles incoming messages and returns a stream of messages and
|
||||
and the final item is the response. The base implementation in :class:`BaseChatAgent`
|
||||
simply calls :meth:`on_messages` and yields the messages in the response."""
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
from typing import AsyncGenerator, List, Sequence
|
||||
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components import Image
|
||||
from autogen_core.components.models import ChatCompletionClient
|
||||
from autogen_core.components.models._types import SystemMessage
|
||||
|
||||
from autogen_agentchat.base import Response
|
||||
|
||||
from ..base import TaskResult, Team
|
||||
from ..messages import (
|
||||
AgentMessage,
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
|
||||
class SocietyOfMindAgent(BaseChatAgent):
    """An agent that uses an inner team of agents to generate responses.

    Each time the agent's :meth:`on_messages` or :meth:`on_messages_stream`
    method is called, it runs the inner team of agents and then uses the
    model client to generate a response based on the inner team's messages.
    Once the response is generated, the agent resets the inner team by
    calling :meth:`Team.reset`.

    Args:
        name (str): The name of the agent.
        team (Team): The team of agents to use.
        model_client (ChatCompletionClient): The model client to use for preparing responses.
        description (str, optional): The description of the agent.
        task_prompt (str, optional): Template used to pass the incoming messages to the
            inner team as a task; must contain the ``{transcript}`` placeholder.
        response_prompt (str, optional): Template used to ask the model client for the
            final response; must contain the ``{transcript}`` placeholder.

    Example:

    .. code-block:: python

        import asyncio
        from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent
        from autogen_ext.models import OpenAIChatCompletionClient
        from autogen_agentchat.teams import RoundRobinGroupChat
        from autogen_agentchat.task import MaxMessageTermination


        async def main() -> None:
            model_client = OpenAIChatCompletionClient(model="gpt-4o")

            agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
            agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
            inner_termination = MaxMessageTermination(3)
            inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)

            society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)

            agent3 = AssistantAgent("assistant3", model_client=model_client, system_message="You are a helpful assistant.")
            agent4 = AssistantAgent("assistant4", model_client=model_client, system_message="You are a helpful assistant.")
            outer_termination = MaxMessageTermination(10)
            team = RoundRobinGroupChat([society_of_mind_agent, agent3, agent4], termination_condition=outer_termination)

            stream = team.run_stream(task="Tell me a one-liner joke.")
            async for message in stream:
                print(message)


        asyncio.run(main())
    """

    def __init__(
        self,
        name: str,
        team: Team,
        model_client: ChatCompletionClient,
        *,
        description: str = "An agent that uses an inner team of agents to generate responses.",
        task_prompt: str = "{transcript}\nContinue.",
        # NOTE: fixed a stray backslash that injected a literal "\" before "Provide"
        # into every summarization prompt sent to the model.
        response_prompt: str = "Here is a transcript of conversation so far:\n{transcript}\nProvide a response to the original request.",
    ) -> None:
        super().__init__(name=name, description=description)
        self._team = team
        self._model_client = model_client
        # Validate the prompt templates eagerly so a misconfiguration fails at
        # construction time rather than mid-conversation.
        if "{transcript}" not in task_prompt:
            raise ValueError("The task prompt must contain the '{transcript}' placeholder for the transcript.")
        self._task_prompt = task_prompt
        if "{transcript}" not in response_prompt:
            raise ValueError("The response prompt must contain the '{transcript}' placeholder for the transcript.")
        self._response_prompt = response_prompt

    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        # The final response is always a plain text message.
        return [TextMessage]

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        # Delegate to the streaming implementation and return its final Response.
        response: Response | None = None
        async for msg in self.on_messages_stream(messages, cancellation_token):
            if isinstance(msg, Response):
                response = msg
        assert response is not None
        return response

    async def on_messages_stream(
        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
    ) -> AsyncGenerator[AgentMessage | Response, None]:
        # Build the inner team's task from the incoming messages, if any.
        delta = list(messages)
        task: str | None = None
        if len(delta) > 0:
            task = self._task_prompt.format(transcript=self._create_transcript(delta))

        # Run the team of agents, yielding each inner message as it arrives.
        result: TaskResult | None = None
        inner_messages: List[AgentMessage] = []
        async for inner_msg in self._team.run_stream(task=task, cancellation_token=cancellation_token):
            if isinstance(inner_msg, TaskResult):
                result = inner_msg
            else:
                yield inner_msg
                inner_messages.append(inner_msg)
        assert result is not None

        if len(inner_messages) < 2:
            # The first message is the task message so we need at least 2 messages.
            yield Response(
                chat_message=TextMessage(source=self.name, content="No response."), inner_messages=inner_messages
            )
        else:
            # Summarize the inner conversation (skipping the task message) into
            # the final response via the model client.
            prompt = self._response_prompt.format(transcript=self._create_transcript(inner_messages[1:]))
            completion = await self._model_client.create(
                messages=[SystemMessage(content=prompt)], cancellation_token=cancellation_token
            )
            assert isinstance(completion.content, str)
            yield Response(
                chat_message=TextMessage(source=self.name, content=completion.content, models_usage=completion.usage),
                inner_messages=inner_messages,
            )

        # Reset the inner team so the next call starts from a clean state.
        await self._team.reset()

    async def reset(self, cancellation_token: CancellationToken) -> None:
        # Resetting this agent means resetting its inner team.
        await self._team.reset()

    def _create_transcript(self, messages: Sequence[AgentMessage]) -> str:
        """Render *messages* as a plain-text transcript, one line per message part."""
        transcript = ""
        for message in messages:
            if isinstance(message, TextMessage | StopMessage | HandoffMessage):
                transcript += f"{message.source}: {message.content}\n"
            elif isinstance(message, MultiModalMessage):
                # Images cannot be rendered as text; use a placeholder.
                for content in message.content:
                    if isinstance(content, Image):
                        transcript += f"{message.source}: [Image]\n"
                    else:
                        transcript += f"{message.source}: {content}\n"
            else:
                raise ValueError(f"Unexpected message type: {message} in {self.__class__.__name__}")
        return transcript
|
||||
@@ -3,7 +3,7 @@ from typing import AsyncGenerator, List, Protocol, Sequence, runtime_checkable
|
||||
|
||||
from autogen_core.base import CancellationToken
|
||||
|
||||
from ..messages import ChatMessage, InnerMessage
|
||||
from ..messages import AgentMessage, ChatMessage
|
||||
from ._task import TaskRunner
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ class Response:
|
||||
chat_message: ChatMessage
|
||||
"""A chat message produced by the agent as the response."""
|
||||
|
||||
inner_messages: List[InnerMessage] | None = None
|
||||
inner_messages: List[AgentMessage] | None = None
|
||||
"""Inner messages produced by the agent."""
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ class ChatAgent(TaskRunner, Protocol):
|
||||
|
||||
def on_messages_stream(
|
||||
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
|
||||
) -> AsyncGenerator[InnerMessage | Response, None]:
|
||||
) -> AsyncGenerator[AgentMessage | Response, None]:
|
||||
"""Handles incoming messages and returns a stream of inner messages and
|
||||
and the final item is the response."""
|
||||
...
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
import asyncio
|
||||
from typing import Any, AsyncGenerator, List
|
||||
|
||||
import pytest
|
||||
from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent
|
||||
from autogen_agentchat.task import MaxMessageTermination
|
||||
from autogen_agentchat.teams import RoundRobinGroupChat
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from openai.resources.chat.completions import AsyncCompletions
|
||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
|
||||
class _MockChatCompletion:
    """Test double for ``AsyncCompletions.create`` that replays a fixed
    sequence of canned chat completions, one per call."""

    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
        self._saved_chat_completions = chat_completions
        self._curr_index = 0

    async def mock_create(
        self, *args: Any, **kwargs: Any
    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
        # Mimic a small network delay, then hand out the next canned completion.
        await asyncio.sleep(0.1)
        reply = self._saved_chat_completions[self._curr_index]
        self._curr_index += 1
        return reply
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_society_of_mind_agent(monkeypatch: pytest.MonkeyPatch) -> None:
    """The agent runs its inner team, then appends its own summary message."""
    model = "gpt-4o-2024-05-13"
    # Three canned completions: two inner-team assistant turns plus the
    # society-of-mind summary; they differ only in their content.
    chat_completions = [
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content=reply, role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        )
        for reply in ("1", "2", "3")
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="")

    assistant_one = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
    assistant_two = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
    inner_termination = MaxMessageTermination(3)
    inner_team = RoundRobinGroupChat([assistant_one, assistant_two], termination_condition=inner_termination)
    society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
    response = await society_of_mind_agent.run(task="Count to 10.")

    # Task message, inner-team task echo, two inner turns, then the summary.
    expected_sources = ["user", "user", "assistant1", "assistant2", "society_of_mind"]
    assert len(response.messages) == len(expected_sources)
    for actual, expected in zip(response.messages, expected_sources):
        assert actual.source == expected
|
||||
Reference in New Issue
Block a user