mirror of
https://github.com/microsoft/autogen.git
synced 2026-05-13 03:00:55 -04:00
Add examples for mixture of agents; patch doc strings (#108)
* add examples for mixture of agents * format
This commit is contained in:
@@ -41,6 +41,10 @@ custom agents and message types for building applications.
|
||||
- `inner_outer.py`: An example of how to create an inner and outer custom agent.
|
||||
- `chat_room.py`: An example of how to create a chat room of custom agents without
|
||||
a centralized orchestrator.
|
||||
- `mixture_of_agents_pub_sub.py`: An example of how to create [a mixture of agents](https://github.com/togethercomputer/moa)
|
||||
that communicate using a publish-subscribe pattern.
|
||||
- `mixture_of_agents_gather.py`: An example of how to create [a mixture of agents](https://github.com/togethercomputer/moa)
|
||||
that communicate using an async distribute-gather pattern.
|
||||
|
||||
## Running the examples
|
||||
|
||||
|
||||
140
python/examples/mixture_of_agents_gather.py
Normal file
140
python/examples/mixture_of_agents_gather.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""This example demonstrates the mixture of agents implemented using direct
|
||||
messaging and async gathering of results.
|
||||
Mixture of agents: https://github.com/togethercomputer/moa"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components.models import ChatCompletionClient, OpenAI, SystemMessage, UserMessage
|
||||
from agnext.core import AgentId, CancellationToken
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReferenceAgentTask:
|
||||
task: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReferenceAgentTaskResult:
|
||||
result: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregatorTask:
|
||||
task: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregatorTaskResult:
|
||||
result: str
|
||||
|
||||
|
||||
class ReferenceAgent(TypeRoutedAgent):
|
||||
"""The reference agent that handles each task independently."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
system_messages: List[SystemMessage],
|
||||
model_client: ChatCompletionClient,
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
self._system_messages = system_messages
|
||||
self._model_client = model_client
|
||||
|
||||
@message_handler
|
||||
async def handle_task(
|
||||
self, message: ReferenceAgentTask, cancellation_token: CancellationToken
|
||||
) -> ReferenceAgentTaskResult:
|
||||
"""Handle a task message. This method sends the task to the model and respond with the result."""
|
||||
task_message = UserMessage(content=message.task, source=self.metadata["name"])
|
||||
response = await self._model_client.create(self._system_messages + [task_message])
|
||||
assert isinstance(response.content, str)
|
||||
return ReferenceAgentTaskResult(result=response.content)
|
||||
|
||||
|
||||
class AggregatorAgent(TypeRoutedAgent):
|
||||
"""The aggregator agent that distribute tasks to reference agents and aggregates the results."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
system_messages: List[SystemMessage],
|
||||
model_client: ChatCompletionClient,
|
||||
references: List[AgentId],
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
self._system_messages = system_messages
|
||||
self._model_client = model_client
|
||||
self._references = references
|
||||
|
||||
@message_handler
|
||||
async def handle_task(self, message: AggregatorTask, cancellation_token: CancellationToken) -> AggregatorTaskResult:
|
||||
"""Handle a task message. This method sends the task to the reference agents
|
||||
and aggregates the results."""
|
||||
ref_task = ReferenceAgentTask(task=message.task)
|
||||
results: List[ReferenceAgentTaskResult] = await asyncio.gather(
|
||||
*[self.send_message(ref_task, ref) for ref in self._references]
|
||||
)
|
||||
combined_result = "\n\n".join([r.result for r in results])
|
||||
response = await self._model_client.create(
|
||||
self._system_messages + [UserMessage(content=combined_result, source=self.metadata["name"])]
|
||||
)
|
||||
assert isinstance(response.content, str)
|
||||
return AggregatorTaskResult(result=response.content)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
ref1 = runtime.register_and_get(
|
||||
"ReferenceAgent1",
|
||||
lambda: ReferenceAgent(
|
||||
description="Reference Agent 1",
|
||||
system_messages=[SystemMessage("You are a helpful assistant that can answer questions.")],
|
||||
model_client=OpenAI(model="gpt-3.5-turbo", temperature=0.1),
|
||||
),
|
||||
)
|
||||
ref2 = runtime.register_and_get(
|
||||
"ReferenceAgent2",
|
||||
lambda: ReferenceAgent(
|
||||
description="Reference Agent 2",
|
||||
system_messages=[SystemMessage("You are a helpful assistant that can answer questions.")],
|
||||
model_client=OpenAI(model="gpt-3.5-turbo", temperature=0.5),
|
||||
),
|
||||
)
|
||||
ref3 = runtime.register_and_get(
|
||||
"ReferenceAgent3",
|
||||
lambda: ReferenceAgent(
|
||||
description="Reference Agent 3",
|
||||
system_messages=[SystemMessage("You are a helpful assistant that can answer questions.")],
|
||||
model_client=OpenAI(model="gpt-3.5-turbo", temperature=1.0),
|
||||
),
|
||||
)
|
||||
agg = runtime.register_and_get(
|
||||
"AggregatorAgent",
|
||||
lambda: AggregatorAgent(
|
||||
description="Aggregator Agent",
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
"...synthesize these responses into a single, high-quality response... Responses from models:"
|
||||
)
|
||||
],
|
||||
model_client=OpenAI(model="gpt-3.5-turbo"),
|
||||
references=[ref1, ref2, ref3],
|
||||
),
|
||||
)
|
||||
result = runtime.send_message(AggregatorTask(task="What are something fun to do in SF?"), agg)
|
||||
while result.done() is False:
|
||||
await runtime.process_next()
|
||||
print(result.result())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
asyncio.run(main())
|
||||
148
python/examples/mixture_of_agents_pub_sub.py
Normal file
148
python/examples/mixture_of_agents_pub_sub.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""This example demonstrates the mixture of agents implemented using pub/sub messaging.
|
||||
Mixture of agents: https://github.com/togethercomputer/moa"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.components.models import ChatCompletionClient, OpenAI, SystemMessage, UserMessage
|
||||
from agnext.core import CancellationToken
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReferenceAgentTask:
|
||||
session_id: str
|
||||
task: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReferenceAgentTaskResult:
|
||||
session_id: str
|
||||
result: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregatorTask:
|
||||
task: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregatorTaskResult:
|
||||
result: str
|
||||
|
||||
|
||||
class ReferenceAgent(TypeRoutedAgent):
|
||||
"""The reference agent that handles each task independently."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
system_messages: List[SystemMessage],
|
||||
model_client: ChatCompletionClient,
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
self._system_messages = system_messages
|
||||
self._model_client = model_client
|
||||
|
||||
@message_handler
|
||||
async def handle_task(self, message: ReferenceAgentTask, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a task message. This method sends the task to the model and publishes the result."""
|
||||
task_message = UserMessage(content=message.task, source=self.metadata["name"])
|
||||
response = await self._model_client.create(self._system_messages + [task_message])
|
||||
assert isinstance(response.content, str)
|
||||
task_result = ReferenceAgentTaskResult(session_id=message.session_id, result=response.content)
|
||||
await self.publish_message(task_result)
|
||||
|
||||
|
||||
class AggregatorAgent(TypeRoutedAgent):
|
||||
"""The aggregator agent that distribute tasks to reference agents and aggregates the results."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
system_messages: List[SystemMessage],
|
||||
model_client: ChatCompletionClient,
|
||||
num_references: int,
|
||||
) -> None:
|
||||
super().__init__(description)
|
||||
self._system_messages = system_messages
|
||||
self._model_client = model_client
|
||||
self._num_references = num_references
|
||||
self._session_results: Dict[str, List[ReferenceAgentTaskResult]] = {}
|
||||
|
||||
@message_handler
|
||||
async def handle_task(self, message: AggregatorTask, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a task message. This method publishes the task to the reference agents."""
|
||||
session_id = str(uuid.uuid4())
|
||||
ref_task = ReferenceAgentTask(session_id=session_id, task=message.task)
|
||||
await self.publish_message(ref_task)
|
||||
|
||||
@message_handler
|
||||
async def handle_result(self, message: ReferenceAgentTaskResult, cancellation_token: CancellationToken) -> None:
|
||||
"""Handle a task result message. Once all results are received, this method
|
||||
aggregates the results and publishes the final result."""
|
||||
self._session_results.setdefault(message.session_id, []).append(message)
|
||||
if len(self._session_results[message.session_id]) == self._num_references:
|
||||
result = "\n\n".join([r.result for r in self._session_results[message.session_id]])
|
||||
response = await self._model_client.create(
|
||||
self._system_messages + [UserMessage(content=result, source=self.metadata["name"])]
|
||||
)
|
||||
assert isinstance(response.content, str)
|
||||
task_result = AggregatorTaskResult(result=response.content)
|
||||
await self.publish_message(task_result)
|
||||
self._session_results.pop(message.session_id)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
runtime.register(
|
||||
"ReferenceAgent1",
|
||||
lambda: ReferenceAgent(
|
||||
description="Reference Agent 1",
|
||||
system_messages=[SystemMessage("You are a helpful assistant that can answer questions.")],
|
||||
model_client=OpenAI(model="gpt-3.5-turbo", temperature=0.1),
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
"ReferenceAgent2",
|
||||
lambda: ReferenceAgent(
|
||||
description="Reference Agent 2",
|
||||
system_messages=[SystemMessage("You are a helpful assistant that can answer questions.")],
|
||||
model_client=OpenAI(model="gpt-3.5-turbo", temperature=0.5),
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
"ReferenceAgent3",
|
||||
lambda: ReferenceAgent(
|
||||
description="Reference Agent 3",
|
||||
system_messages=[SystemMessage("You are a helpful assistant that can answer questions.")],
|
||||
model_client=OpenAI(model="gpt-3.5-turbo", temperature=1.0),
|
||||
),
|
||||
)
|
||||
runtime.register(
|
||||
"AggregatorAgent",
|
||||
lambda: AggregatorAgent(
|
||||
description="Aggregator Agent",
|
||||
system_messages=[
|
||||
SystemMessage(
|
||||
"...synthesize these responses into a single, high-quality response... Responses from models:"
|
||||
)
|
||||
],
|
||||
model_client=OpenAI(model="gpt-3.5-turbo"),
|
||||
num_references=3,
|
||||
),
|
||||
)
|
||||
await runtime.publish_message(AggregatorTask(task="What are something fun to do in SF?"), namespace="default")
|
||||
while True:
|
||||
await runtime.process_next()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
asyncio.run(main())
|
||||
@@ -37,7 +37,6 @@ class ChatCompletionAgent(TypeRoutedAgent):
|
||||
|
||||
Args:
|
||||
description (str): The description of the agent.
|
||||
runtime (AgentRuntime): The runtime to register the agent.
|
||||
system_messages (List[SystemMessage]): The system messages to use for
|
||||
the ChatCompletion API.
|
||||
memory (ChatMemory[Message]): The memory to store and retrieve messages.
|
||||
|
||||
@@ -15,7 +15,6 @@ class OpenAIAssistantAgent(TypeRoutedAgent):
|
||||
|
||||
Args:
|
||||
description (str): The description of the agent.
|
||||
runtime (AgentRuntime): The runtime to register the agent.
|
||||
client (openai.AsyncClient): The client to use for the OpenAI API.
|
||||
assistant_id (str): The assistant ID to use for the OpenAI API.
|
||||
thread_id (str): The thread ID to use for the OpenAI API.
|
||||
|
||||
@@ -11,7 +11,6 @@ class UserProxyAgent(TypeRoutedAgent):
|
||||
|
||||
Args:
|
||||
description (str): The description of the agent.
|
||||
runtime (AgentRuntime): The runtime to register the agent.
|
||||
user_input_prompt (str): The console prompt to show to the user when asking for input.
|
||||
"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user