mirror of https://github.com/microsoft/autogen.git, synced 2026-01-26 05:08:08 -05:00
HeadAndTail memory; add group chat manager transitions; update software_consultancy example (#72)
@@ -8,19 +8,26 @@ or GitHub Codespaces to run this example.

 import argparse
 import asyncio
 import base64
+import logging

 import aiofiles
 import aiohttp
 import openai
 from agnext.application import SingleThreadedAgentRuntime
 from agnext.chat.agents import ChatCompletionAgent, UserProxyAgent
 from agnext.chat.memory import BufferedChatMemory
+from agnext.chat.memory import HeadAndTailChatMemory
 from agnext.chat.patterns.group_chat_manager import GroupChatManager
 from agnext.chat.types import PublishNow
 from agnext.components.models import OpenAI, SystemMessage
 from agnext.components.tools import FunctionTool
 from agnext.core import AgentRuntime
 from markdownify import markdownify  # type: ignore
 from tqdm import tqdm
 from typing_extensions import Annotated

+sep = "+------------------------------------------------------------+"


 async def get_user_input(prompt: str) -> Annotated[str, "The user input."]:
     return await asyncio.get_event_loop().run_in_executor(None, input, prompt)
@@ -29,14 +36,15 @@ async def get_user_input(prompt: str) -> Annotated[str, "The user input."]:
 async def confirm(message: str) -> None:
     user_input = await get_user_input(f"{message} (yes/no): ")
     if user_input.lower() not in ["yes", "y"]:
-        raise ValueError("Operation cancelled by system.")
+        raise ValueError(f"Operation cancelled: reason: {user_input}")


-async def write_file(filename: str, content: str) -> None:
+async def write_file(filename: str, content: str) -> str:
     # Ask for confirmation first.
     await confirm(f"Are you sure you want to write to {filename}?")
     async with aiofiles.open(filename, "w") as file:
         await file.write(content)
+    return f"Content written to {filename}."


 async def execute_command(command: str) -> Annotated[str, "The standard output and error of the executed command."]:
@@ -53,18 +61,19 @@ async def execute_command(command: str) -> Annotated[str, "The standard output a

 async def read_file(filename: str) -> Annotated[str, "The content of the file."]:
     # Ask for confirmation first.
-    # await confirm(f"Are you sure you want to read {filename}?")
+    await confirm(f"Are you sure you want to read {filename}?")
     async with aiofiles.open(filename, "r") as file:
         return await file.read()


-async def remove_file(filename: str) -> None:
+async def remove_file(filename: str) -> str:
     # Ask for confirmation first.
     await confirm(f"Are you sure you want to remove {filename}?")
     process = await asyncio.subprocess.create_subprocess_exec("rm", filename)
     await process.wait()
     if process.returncode != 0:
         raise ValueError(f"Error occurred while removing file: {filename}")
+    return f"File removed: {filename}."


 async def list_files(directory: str) -> Annotated[str, "The list of files in the directory."]:
@@ -82,14 +91,55 @@ async def list_files(directory: str) -> Annotated[str, "The list of files in the
     return stdout.decode()


 async def browse_web(url: str) -> Annotated[str, "The content of the web page in Markdown format."]:
     # Ask for confirmation first.
     await confirm(f"Are you sure you want to browse {url}?")
     async with aiohttp.ClientSession() as session:
         async with session.get(url) as response:
             html = await response.text()
             markdown = markdownify(html)  # type: ignore
             if isinstance(markdown, str):
                 return markdown
             return f"Unable to parse content from {url}."


+async def create_image(
+    description: Annotated[str, "Describe the image to create"],
+    filename: Annotated[str, "The path to save the created image"],
+) -> str:
+    # Ask for confirmation first.
+    await confirm(f"Are you sure you want to create an image with description: {description}?")
+    # Use DALL-E to generate an image from the description.
+    with tqdm(desc="Generating image...", leave=False) as pbar:
+        client = openai.AsyncClient()
+        response = await client.images.generate(model="dall-e-2", prompt=description, response_format="b64_json")
+        pbar.close()
+    assert len(response.data) > 0 and response.data[0].b64_json is not None
+    # Save the image to a file.
+    async with aiofiles.open(filename, "wb") as file:
+        image_data = base64.b64decode(response.data[0].b64_json)
+        await file.write(image_data)
+    return f"Image created and saved to {filename}."


 def software_consultancy(runtime: AgentRuntime) -> UserProxyAgent:  # type: ignore
     developer = ChatCompletionAgent(
         name="Developer",
         description="A Python software developer.",
         runtime=runtime,
-        system_messages=[SystemMessage("Your are a Python developer. Use your skills to write Python code.")],
+        system_messages=[
+            SystemMessage(
+                "You are a Python developer. \n"
+                "You can read, write, and execute code. \n"
+                "You can browse files and directories. \n"
+                "You can also browse the web for documentation. \n"
+                "You are entering a work session with the customer, product manager, UX designer, and illustrator. \n"
+                "When you are given a task, you should immediately start working on it. \n"
+                "Be concise and deliver now."
+            )
+        ],
         model_client=OpenAI(model="gpt-4-turbo"),
-        memory=BufferedChatMemory(buffer_size=10),
+        memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
         tools=[
             FunctionTool(
                 write_file,
@@ -101,45 +151,32 @@ def software_consultancy(runtime: AgentRuntime) -> UserProxyAgent: # type: igno
                 name="read_file",
                 description="Read code from a file.",
             ),
             FunctionTool(list_files, name="list_files", description="List files in a directory."),
         ],
     )
     tester = ChatCompletionAgent(
         name="Tester",
         description="A Python software tester.",
         runtime=runtime,
         system_messages=[
             SystemMessage(
                 "You are a Python tester. Use your skills to test code written by the developer and designer."
             )
         ],
         model_client=OpenAI(model="gpt-4-turbo"),
         memory=BufferedChatMemory(buffer_size=10),
         tools=[
             FunctionTool(
                 execute_command,
                 name="execute_command",
                 description="Execute a unix shell command.",
             ),
             FunctionTool(
                 read_file,
                 name="read_file",
                 description="Read code from a file.",
             ),
             FunctionTool(list_files, name="list_files", description="List files in a directory."),
             FunctionTool(browse_web, name="browse_web", description="Browse a web page."),
         ],
     )
     product_manager = ChatCompletionAgent(
         name="ProductManager",
-        description="A product manager for a software consultancy. Interface with the customer and gather requirements for the developer.",
+        description="A product manager. "
+        "Responsible for interfacing with the customer, planning and managing the project. ",
         runtime=runtime,
         system_messages=[
             SystemMessage(
-                "You are a product manager. Interface with the customer and gather requirements for the developer and user experience designer."
+                "You are a product manager. \n"
+                "You can browse files and directories. \n"
+                "You are entering a work session with the customer, developer, UX designer, and illustrator. \n"
+                "Keep the project on track. Don't hire any more people. \n"
+                "When a milestone is reached, stop and ask for customer feedback. Make sure the customer is satisfied. \n"
+                "Be VERY concise."
             )
         ],
         model_client=OpenAI(model="gpt-4-turbo"),
-        memory=BufferedChatMemory(buffer_size=10),
+        memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
         tools=[
             FunctionTool(
                 read_file,
@@ -147,6 +184,7 @@ def software_consultancy(runtime: AgentRuntime) -> UserProxyAgent: # type: igno
                 description="Read from a file.",
             ),
             FunctionTool(list_files, name="list_files", description="List files in a directory."),
+            FunctionTool(browse_web, name="browse_web", description="Browse a web page."),
         ],
     )
     ux_designer = ChatCompletionAgent(
@@ -154,10 +192,17 @@ def software_consultancy(runtime: AgentRuntime) -> UserProxyAgent: # type: igno
         description="A user experience designer for creating user interfaces.",
         runtime=runtime,
         system_messages=[
-            SystemMessage("You are a user experience designer. Design user interfaces for the developer.")
+            SystemMessage(
+                "You are a user experience designer. \n"
+                "You can create user interfaces from descriptions. \n"
+                "You can browse files and directories. \n"
+                "You are entering a work session with the customer, developer, product manager, and illustrator. \n"
+                "When you are given a task, you should immediately start working on it. \n"
+                "Be concise and deliver now."
+            )
         ],
         model_client=OpenAI(model="gpt-4-turbo"),
-        memory=BufferedChatMemory(buffer_size=10),
+        memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
         tools=[
             FunctionTool(
                 write_file,
@@ -172,20 +217,43 @@ def software_consultancy(runtime: AgentRuntime) -> UserProxyAgent: # type: igno
             FunctionTool(list_files, name="list_files", description="List files in a directory."),
         ],
     )
+    illustrator = ChatCompletionAgent(
+        name="Illustrator",
+        description="An illustrator for creating images.",
+        runtime=runtime,
+        system_messages=[
+            SystemMessage(
+                "You are an illustrator. "
+                "You can create images from descriptions. "
+                "You are entering a work session with the customer, developer, product manager, and UX designer. \n"
+                "When you are given a task, you should immediately start working on it. \n"
+                "Be concise and deliver now."
+            )
+        ],
+        model_client=OpenAI(model="gpt-4-turbo"),
+        memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
+        tools=[
+            FunctionTool(
+                create_image,
+                name="create_image",
+                description="Create an image from a description.",
+            ),
+        ],
+    )
     customer = UserProxyAgent(
         name="Customer",
         description="A customer requesting help.",
         runtime=runtime,
-        user_input_prompt=f"{'-'*50}\nYou:\n",
+        user_input_prompt=f"{sep}\nYou:\n",
     )
     _ = GroupChatManager(
         name="GroupChatManager",
         description="A group chat manager.",
         runtime=runtime,
-        memory=BufferedChatMemory(buffer_size=10),
+        memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
         model_client=OpenAI(model="gpt-4-turbo"),
-        participants=[developer, tester, product_manager, ux_designer, customer],
-        on_message_received=lambda message: print(f"{'-'*50}\n{message.source}: {message.content}"),
+        participants=[developer, product_manager, ux_designer, illustrator, customer],
+        on_message_received=lambda message: print(f"{sep}\n{message.source}: {message.content}"),
     )
     return customer
@@ -202,11 +270,34 @@ async def main() -> None:


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Chat with software development team.")
+    description = "Work with a software development consultancy to create your own Python application."
+    art = r"""
+ +------------------------------------------------------------+
+ |  ____         __ _                                         |
+ | / ___|  ___  / _| |___      ____ _ _ __ ___                |
+ | \___ \ / _ \| |_| __\ \ /\ / / _` | '__/ _ \               |
+ |  ___) | (_) |  _| |_ \ V  V / (_| | | |  __/               |
+ | |____/ \___/|_|  \__| \_/\_/ \__,_|_|  \___|               |
+ |                                                            |
+ |   ____                      _ _                            |
+ |  / ___|___  _ __  ___ _   _| | |_ __ _ _ __   ___ _   _    |
+ | | |   / _ \| '_ \/ __| | | | | __/ _` | '_ \ / __| | | |   |
+ | | |__| (_) | | | \__ \ |_| | | || (_| | | | | (__| |_| |   |
+ |  \____\___/|_| |_|___/\__,_|_|\__\__,_|_| |_|\___|\__, |   |
+ |                                                    |___/   |
+ |                                                            |
+ +------------------------------------------------------------+
+ | Work with a software development consultancy to create     |
+ | your own Python application. You can start by greeting     |
+ | the team!                                                  |
+ +------------------------------------------------------------+
"""
|
||||
parser = argparse.ArgumentParser(description="Software consultancy demo.")
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
|
||||
args = parser.parse_args()
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
|
||||
print(art)
|
||||
asyncio.run(main())
|
||||
|
||||
src/agnext/chat/memory/__init__.py
@@ -1,4 +1,5 @@
 from ._base import ChatMemory
 from ._buffered import BufferedChatMemory
+from ._head_and_tail import HeadAndTailChatMemory

-__all__ = ["ChatMemory", "BufferedChatMemory"]
+__all__ = ["ChatMemory", "BufferedChatMemory", "HeadAndTailChatMemory"]
src/agnext/chat/memory/_head_and_tail.py (new file)
@@ -0,0 +1,66 @@
+from typing import Any, List, Mapping
+
+from ...components.models import FunctionExecutionResultMessage
+from ..types import FunctionCallMessage, Message, TextMessage
+from ._base import ChatMemory
+
+
+class HeadAndTailChatMemory(ChatMemory):
+    """A chat memory that keeps a view of the first n and last m messages,
+    where n is the head size and m is the tail size. The head and tail sizes
+    are set at initialization.
+
+    Args:
+        head_size (int): The size of the head.
+        tail_size (int): The size of the tail.
+    """
+
+    def __init__(self, head_size: int, tail_size: int) -> None:
+        self._messages: List[Message] = []
+        self._head_size = head_size
+        self._tail_size = tail_size
+
+    async def add_message(self, message: Message) -> None:
+        """Add a message to the memory."""
+        self._messages.append(message)
+
+    async def get_messages(self) -> List[Message]:
+        """Get at most `head_size` oldest messages and `tail_size` most recent messages."""
+        head_messages = self._messages[: self._head_size]
+        # If the head ends with a function call message, drop it so the call
+        # is not left in view without its matching result.
+        if head_messages and isinstance(head_messages[-1], FunctionCallMessage):
+            head_messages = head_messages[:-1]
+
+        tail_messages = self._messages[-self._tail_size :]
+        # If the tail starts with a function execution result message, drop it
+        # since its originating call is not in view.
+        if tail_messages and isinstance(tail_messages[0], FunctionExecutionResultMessage):
+            tail_messages = tail_messages[1:]
+
+        num_skipped = len(self._messages) - self._head_size - self._tail_size
+        if num_skipped <= 0:
+            # If there are not enough messages to fill the head and tail,
+            # return all messages.
+            return self._messages
+
+        placeholder_messages = [TextMessage(content=f"Skipped {num_skipped} messages.", source="System")]
+        return head_messages + placeholder_messages + tail_messages
+
+    async def clear(self) -> None:
+        """Clear the message memory."""
+        self._messages = []
+
+    def save_state(self) -> Mapping[str, Any]:
+        return {
+            "messages": list(self._messages),
+            "head_size": self._head_size,
+            "tail_size": self._tail_size,
+        }
+
+    def load_state(self, state: Mapping[str, Any]) -> None:
+        self._messages = state["messages"]
+        self._head_size = state["head_size"]
+        self._tail_size = state["tail_size"]
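For illustration only, not part of the commit: a minimal sketch of how the new memory trims its view. It assumes only what the diff shows, namely the package-root export added in __init__.py above and the TextMessage(content=..., source=...) constructor the class itself uses.

import asyncio

from agnext.chat.memory import HeadAndTailChatMemory
from agnext.chat.types import TextMessage


async def demo() -> None:
    # Keep the first message and the last three messages in view.
    memory = HeadAndTailChatMemory(head_size=1, tail_size=3)
    for i in range(10):
        await memory.add_message(TextMessage(content=f"message {i}", source="User"))
    # Expected view: message 0, a "Skipped 6 messages." placeholder, then messages 7-9.
    for m in await memory.get_messages():
        print(m.content)


asyncio.run(demo())

The placeholder arithmetic is num_skipped = 10 - head_size - tail_size = 6, so the view stays bounded no matter how long the conversation grows, while the head preserves the earliest context (for example, the original task).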
src/agnext/chat/patterns/group_chat_manager.py
@@ -26,6 +26,11 @@ class GroupChatManager(TypeRoutedAgent):
             If not provided, the agent will select the next speaker from the list of participants
             according to the order given.
         termination_word (str, optional): The word that terminates the group chat. Defaults to "TERMINATE".
+        transitions (Mapping[Agent, List[Agent]], optional): The transitions between agents.
+            Keys are the agents, and values are the lists of agents that can follow the key agent. Defaults to {}.
+            If provided, the group chat manager will use the transitions to select the next speaker.
+            If a transition is not provided for an agent, the choices fall back to all participants.
+            This setting is only used when a model client is provided.
         on_message_received (Callable[[TextMessage], None], optional): A custom handler to call when a message is received.
             Defaults to None.
     """
@@ -39,6 +44,7 @@ class GroupChatManager(TypeRoutedAgent):
         memory: ChatMemory,
         model_client: ChatCompletionClient | None = None,
         termination_word: str = "TERMINATE",
+        transitions: Mapping[Agent, List[Agent]] = {},
         on_message_received: Callable[[TextMessage], None] | None = None,
     ):
         super().__init__(name, description, runtime)
@@ -46,6 +52,18 @@ class GroupChatManager(TypeRoutedAgent):
         self._client = model_client
         self._participants = participants
         self._termination_word = termination_word
+        for key, value in transitions.items():
+            if not value:
+                # Make sure no empty transition lists are provided.
+                raise ValueError(f"Empty transition list provided for {key.name}.")
+            if key not in participants:
+                # Make sure all keys are in the list of participants.
+                raise ValueError(f"Transition key {key.name} not found in participants.")
+            for v in value:
+                if v not in participants:
+                    # Make sure all values are in the list of participants.
+                    raise ValueError(f"Transition value {v.name} not found in participants.")
+        self._transitions = transitions
         self._on_message_received = on_message_received

     @message_handler()
@@ -69,13 +87,13 @@ class GroupChatManager(TypeRoutedAgent):
         # Save the message to chat memory.
         await self._memory.add_message(message)

+        # Get the last speaker.
+        last_speaker_name = message.source
+        last_speaker_index = next((i for i, p in enumerate(self._participants) if p.name == last_speaker_name), None)
+
         # Select speaker.
         if self._client is None:
             # If no model client is provided, select the next speaker from the list of participants.
-            last_speaker_name = message.source
-            last_speaker_index = next(
-                (i for i, p in enumerate(self._participants) if p.name == last_speaker_name), None
-            )
             if last_speaker_index is None:
                 # If the last speaker is not found, select the first speaker in the list.
                 next_speaker_index = 0
@@ -83,8 +101,16 @@ class GroupChatManager(TypeRoutedAgent):
                 next_speaker_index = (last_speaker_index + 1) % len(self._participants)
             speaker = self._participants[next_speaker_index]
         else:
-            # If a model client is provided, select the speaker based on the model output.
-            speaker = await select_speaker(self._memory, self._client, self._participants)
+            # If a model client is provided, select the speaker based on the transitions and the model.
+            candidates = self._participants
+            if last_speaker_index is not None:
+                last_speaker = self._participants[last_speaker_index]
+                if self._transitions.get(last_speaker) is not None:
+                    candidates = self._transitions[last_speaker]
+            if len(candidates) == 1:
+                speaker = candidates[0]
+            else:
+                speaker = await select_speaker(self._memory, self._client, candidates)

         # Send the message to the selected speaker to ask it to publish a response.
         await self._send_message(PublishNow(), speaker)
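For illustration only, not part of the commit: a sketch of passing the new transitions argument, reusing the agent names from the software_consultancy example above; the transition policy itself is hypothetical. Per the validation added in __init__, every key and value must be a participant and no transition list may be empty; per the selection logic, transitions only matter when a model client is provided.

_ = GroupChatManager(
    name="GroupChatManager",
    description="A group chat manager.",
    runtime=runtime,
    memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
    model_client=OpenAI(model="gpt-4-turbo"),
    participants=[developer, product_manager, ux_designer, illustrator, customer],
    transitions={
        # After the customer speaks, the product manager always responds.
        customer: [product_manager],
        # The developer may hand off only to the product manager or the UX designer.
        developer: [product_manager, ux_designer],
    },
)

Agents without an entry here fall back to all participants, chosen via select_speaker; a single-candidate list (like the customer's) skips the model call entirely, which both saves a completion request and makes that hop deterministic.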