HeadAndTail memory; add group chat manager transitions; update software_consultancy example (#72)

Eric Zhu
2024-06-12 15:28:00 -07:00
committed by GitHub
parent c36ea487e0
commit f754f3ce2f
4 changed files with 228 additions and 44 deletions

View File

@@ -8,19 +8,26 @@ or GitHub Codespaces to run this example.
import argparse
import asyncio
import base64
import logging
import aiofiles
import aiohttp
import openai
from agnext.application import SingleThreadedAgentRuntime
from agnext.chat.agents import ChatCompletionAgent, UserProxyAgent
from agnext.chat.memory import BufferedChatMemory
from agnext.chat.memory import HeadAndTailChatMemory
from agnext.chat.patterns.group_chat_manager import GroupChatManager
from agnext.chat.types import PublishNow
from agnext.components.models import OpenAI, SystemMessage
from agnext.components.tools import FunctionTool
from agnext.core import AgentRuntime
from markdownify import markdownify # type: ignore
from tqdm import tqdm
from typing_extensions import Annotated
sep = "+----------------------------------------------------------+"
async def get_user_input(prompt: str) -> Annotated[str, "The user input."]:
return await asyncio.get_event_loop().run_in_executor(None, input, prompt)
@@ -29,14 +36,15 @@ async def get_user_input(prompt: str) -> Annotated[str, "The user input."]:
async def confirm(message: str) -> None:
user_input = await get_user_input(f"{message} (yes/no): ")
if user_input.lower() not in ["yes", "y"]:
raise ValueError("Operation cancelled by system.")
raise ValueError(f"Operation cancelled: reason: {user_input}")
async def write_file(filename: str, content: str) -> None:
async def write_file(filename: str, content: str) -> str:
# Ask for confirmation first.
await confirm(f"Are you sure you want to write to {filename}?")
async with aiofiles.open(filename, "w") as file:
await file.write(content)
return f"Content written to {filename}."
async def execute_command(command: str) -> Annotated[str, "The standard output and error of the executed command."]:
@@ -53,18 +61,19 @@ async def execute_command(command: str) -> Annotated[str, "The standard output a
async def read_file(filename: str) -> Annotated[str, "The content of the file."]:
# Ask for confirmation first.
# await confirm(f"Are you sure you want to read {filename}?")
await confirm(f"Are you sure you want to read {filename}?")
async with aiofiles.open(filename, "r") as file:
return await file.read()
async def remove_file(filename: str) -> None:
async def remove_file(filename: str) -> str:
# Ask for confirmation first.
await confirm(f"Are you sure you want to remove {filename}?")
process = await asyncio.subprocess.create_subprocess_exec("rm", filename)
await process.wait()
if process.returncode != 0:
raise ValueError(f"Error occurred while removing file: {filename}")
return f"File removed: {filename}."
async def list_files(directory: str) -> Annotated[str, "The list of files in the directory."]:
@@ -82,14 +91,55 @@ async def list_files(directory: str) -> Annotated[str, "The list of files in the
return stdout.decode()
async def browse_web(url: str) -> Annotated[str, "The content of the web page in Markdown format."]:
# Ask for confirmation first.
await confirm(f"Are you sure you want to browse {url}?")
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
html = await response.text()
markdown = markdownify(html) # type: ignore
if isinstance(markdown, str):
return markdown
return f"Unable to parse content from {url}."
async def create_image(
description: Annotated[str, "Describe the image to create"],
filename: Annotated[str, "The path to save the created image"],
) -> str:
# Ask for confirmation first.
await confirm(f"Are you sure you want to create an image with description: {description}?")
# Use Dalle to generate an image from the description.
with tqdm(desc="Generating image...", leave=False) as pbar:
client = openai.AsyncClient()
response = await client.images.generate(model="dall-e-2", prompt=description, response_format="b64_json")
pbar.close()
assert len(response.data) > 0 and response.data[0].b64_json is not None
# Save the image to a file.
async with aiofiles.open(filename, "wb") as file:
image_data = base64.b64decode(response.data[0].b64_json)
await file.write(image_data)
return f"Image created and saved to {filename}."
def software_consultancy(runtime: AgentRuntime) -> UserProxyAgent: # type: ignore
developer = ChatCompletionAgent(
name="Developer",
description="A Python software developer.",
runtime=runtime,
system_messages=[SystemMessage("Your are a Python developer. Use your skills to write Python code.")],
system_messages=[
SystemMessage(
"Your are a Python developer. \n"
"You can read, write, and execute code. \n"
"You can browse files and directories. \n"
"You can also browse the web for documentation. \n"
"You are entering a work session with the customer, product manager, UX designer, and illustrator. \n"
"When you are given a task, you should immediately start working on it. \n"
"Be concise and deliver now."
)
],
model_client=OpenAI(model="gpt-4-turbo"),
memory=BufferedChatMemory(buffer_size=10),
memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
tools=[
FunctionTool(
write_file,
@@ -101,45 +151,32 @@ def software_consultancy(runtime: AgentRuntime) -> UserProxyAgent: # type: igno
name="read_file",
description="Read code from a file.",
),
FunctionTool(list_files, name="list_files", description="List files in a directory."),
],
)
tester = ChatCompletionAgent(
name="Tester",
description="A Python software tester.",
runtime=runtime,
system_messages=[
SystemMessage(
"You are a Python tester. Use your skills to test code written by the developer and designer."
)
],
model_client=OpenAI(model="gpt-4-turbo"),
memory=BufferedChatMemory(buffer_size=10),
tools=[
FunctionTool(
execute_command,
name="execute_command",
description="Execute a unix shell command.",
),
FunctionTool(
read_file,
name="read_file",
description="Read code from a file.",
),
FunctionTool(list_files, name="list_files", description="List files in a directory."),
FunctionTool(browse_web, name="browse_web", description="Browse a web page."),
],
)
product_manager = ChatCompletionAgent(
name="ProductManager",
description="A product manager for a software consultancy. Interface with the customer and gather requirements for the developer.",
description="A product manager. "
"Responsible for interfacing with the customer, planning and managing the project. ",
runtime=runtime,
system_messages=[
SystemMessage(
"You are a product manager. Interface with the customer and gather requirements for the developer and user experience designer."
"You are a product manager. \n"
"You can browse files and directories. \n"
"You are entering a work session with the customer, developer, UX designer, and illustrator. \n"
"Keep the project on track. Don't hire any more people. \n"
"When a milestone is reached, stop and ask for customer feedback. Make sure the customer is satisfied. \n"
"Be VERY concise."
)
],
model_client=OpenAI(model="gpt-4-turbo"),
memory=BufferedChatMemory(buffer_size=10),
memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
tools=[
FunctionTool(
read_file,
@@ -147,6 +184,7 @@ def software_consultancy(runtime: AgentRuntime) -> UserProxyAgent: # type: igno
description="Read from a file.",
),
FunctionTool(list_files, name="list_files", description="List files in a directory."),
FunctionTool(browse_web, name="browse_web", description="Browse a web page."),
],
)
ux_designer = ChatCompletionAgent(
@@ -154,10 +192,17 @@ def software_consultancy(runtime: AgentRuntime) -> UserProxyAgent: # type: igno
description="A user experience designer for creating user interfaces.",
runtime=runtime,
system_messages=[
SystemMessage("You are a user experience designer. Design user interfaces for the developer.")
SystemMessage(
"You are a user experience designer. \n"
"You can create user interfaces from descriptions. \n"
"You can browse files and directories. \n"
"You are entering a work session with the customer, developer, product manager, and illustrator. \n"
"When you are given a task, you should immediately start working on it. \n"
"Be concise and deliver now."
)
],
model_client=OpenAI(model="gpt-4-turbo"),
memory=BufferedChatMemory(buffer_size=10),
memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
tools=[
FunctionTool(
write_file,
@@ -172,20 +217,43 @@ def software_consultancy(runtime: AgentRuntime) -> UserProxyAgent: # type: igno
FunctionTool(list_files, name="list_files", description="List files in a directory."),
],
)
illustrator = ChatCompletionAgent(
name="Illustrator",
description="An illustrator for creating images.",
runtime=runtime,
system_messages=[
SystemMessage(
"You are an illustrator. "
"You can create images from descriptions. "
"You are entering a work session with the customer, developer, product manager, and UX designer. \n"
"When you are given a task, you should immediately start working on it. \n"
"Be concise and deliver now."
)
],
model_client=OpenAI(model="gpt-4-turbo"),
memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
tools=[
FunctionTool(
create_image,
name="create_image",
description="Create an image from a description.",
),
],
)
customer = UserProxyAgent(
name="Customer",
description="A customer requesting for help.",
runtime=runtime,
user_input_prompt=f"{'-'*50}\nYou:\n",
user_input_prompt=f"{sep}\nYou:\n",
)
_ = GroupChatManager(
name="GroupChatManager",
description="A group chat manager.",
runtime=runtime,
memory=BufferedChatMemory(buffer_size=10),
memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
model_client=OpenAI(model="gpt-4-turbo"),
participants=[developer, tester, product_manager, ux_designer, customer],
on_message_received=lambda message: print(f"{'-'*50}\n{message.source}: {message.content}"),
participants=[developer, product_manager, ux_designer, illustrator, customer],
on_message_received=lambda message: print(f"{sep}\n{message.source}: {message.content}"),
)
return customer
@@ -202,11 +270,34 @@ async def main() -> None:
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Chat with software development team.")
description = "Work with a software development consultancy to create your own Python application."
art = r"""
+----------------------------------------------------------+
| ____ __ _ |
| / ___| ___ / _| |___ ____ _ _ __ ___ |
| \___ \ / _ \| |_| __\ \ /\ / / _` | '__/ _ \ |
| ___) | (_) | _| |_ \ V V / (_| | | | __/ |
| |____/ \___/|_| \__| \_/\_/ \__,_|_| \___| |
| |
| ____ _ _ |
| / ___|___ _ __ ___ _ _| | |_ __ _ _ __ ___ _ _ |
| | | / _ \| '_ \/ __| | | | | __/ _` | '_ \ / __| | | | |
| | |__| (_) | | | \__ \ |_| | | || (_| | | | | (__| |_| | |
| \____\___/|_| |_|___/\__,_|_|\__\__,_|_| |_|\___|\__, | |
| |___/ |
| |
+----------------------------------------------------------+
| Work with a software development consultancy to create |
| your own Python application. You can start by greeting |
| the team! |
+----------------------------------------------------------+
"""
parser = argparse.ArgumentParser(description="Software consultancy demo.")
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
args = parser.parse_args()
if args.verbose:
logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)
print(art)
asyncio.run(main())

View File

@@ -1,4 +1,5 @@
from ._base import ChatMemory
from ._buffered import BufferedChatMemory
from ._head_and_tail import HeadAndTailChatMemory
__all__ = ["ChatMemory", "BufferedChatMemory"]
__all__ = ["ChatMemory", "BufferedChatMemory", "HeadAndTailChatMemory"]

View File

@@ -0,0 +1,66 @@
from typing import Any, List, Mapping
from ...components.models import FunctionExecutionResultMessage
from ..types import FunctionCallMessage, Message, TextMessage
from ._base import ChatMemory
class HeadAndTailChatMemory(ChatMemory):
"""A chat memory that keeps a view of the first n and last m messages,
where n is the head size and m is the tail size. The head and tail sizes
are set at initialization.
Args:
head_size (int): The size of the head.
tail_size (int): The size of the tail.
"""
def __init__(self, head_size: int, tail_size: int) -> None:
self._messages: List[Message] = []
self._head_size = head_size
self._tail_size = tail_size
async def add_message(self, message: Message) -> None:
"""Add a message to the memory."""
self._messages.append(message)
async def get_messages(self) -> List[Message]:
"""Get at most `head_size` oldest messages and `tail_size` most recent messages."""
if len(self._messages) <= self._head_size + self._tail_size:
# Not enough messages to fill both the head and the tail; return all messages.
return self._messages
head_messages = self._messages[: self._head_size]
# If the head ends with a function call message, drop it so the call is not separated from its result.
if head_messages and isinstance(head_messages[-1], FunctionCallMessage):
head_messages = head_messages[:-1]
tail_messages = self._messages[-self._tail_size :]
# If the tail starts with a function call result message, drop it so the result is not separated from its call.
if tail_messages and isinstance(tail_messages[0], FunctionExecutionResultMessage):
tail_messages = tail_messages[1:]
# Count skipped messages from the trimmed head and tail so the placeholder is accurate.
num_skipped = len(self._messages) - len(head_messages) - len(tail_messages)
placeholder_messages = [TextMessage(content=f"Skipped {num_skipped} messages.", source="System")]
return head_messages + placeholder_messages + tail_messages
async def clear(self) -> None:
"""Clear the message memory."""
self._messages = []
def save_state(self) -> Mapping[str, Any]:
return {
"messages": list(self._messages),
"head_size": self._head_size,
"tail_size": self._tail_size,
}
def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = state["messages"]
self._head_size = state["head_size"]
self._tail_size = state["tail_size"]
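For reference, a minimal usage sketch of the new memory. It assumes `HeadAndTailChatMemory` and `TextMessage` are importable as shown in the diffs above, and that `TextMessage` exposes the `content` attribute passed to its constructor here.

import asyncio

from agnext.chat.memory import HeadAndTailChatMemory
from agnext.chat.types import TextMessage


async def demo() -> None:
    # Keep the 1 oldest and 3 most recent messages in the model's view.
    memory = HeadAndTailChatMemory(head_size=1, tail_size=3)
    for i in range(10):
        await memory.add_message(TextMessage(content=f"message {i}", source="User"))
    # Expected view: "message 0", a "Skipped 6 messages." placeholder, then "message 7" to "message 9".
    for message in await memory.get_messages():
        print(message.content)


asyncio.run(demo())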

View File

@@ -26,6 +26,11 @@ class GroupChatManager(TypeRoutedAgent):
If not provided, the agent will select the next speaker from the list of participants
according to the order given.
termination_word (str, optional): The word that terminates the group chat. Defaults to "TERMINATE".
transitions (Mapping[Agent, List[Agent]], optional): The allowed transitions between agents.
Keys are agents, and values are the lists of agents that may speak after the key agent. Defaults to {}.
If provided, the group chat manager uses the transitions to select the next speaker.
If no transition is provided for an agent, the choices fall back to all participants.
This setting is only used when a model client is provided.
on_message_received (Callable[[TextMessage], None], optional): A custom handler to call when a message is received.
Defaults to None.
"""
@@ -39,6 +44,7 @@ class GroupChatManager(TypeRoutedAgent):
memory: ChatMemory,
model_client: ChatCompletionClient | None = None,
termination_word: str = "TERMINATE",
transitions: Mapping[Agent, List[Agent]] = {},
on_message_received: Callable[[TextMessage], None] | None = None,
):
super().__init__(name, description, runtime)
@@ -46,6 +52,18 @@ class GroupChatManager(TypeRoutedAgent):
self._client = model_client
self._participants = participants
self._termination_word = termination_word
for key, value in transitions.items():
if not value:
# Make sure no empty transition lists are provided.
raise ValueError(f"Empty transition list provided for {key.name}.")
if key not in participants:
# Make sure all keys are in the list of participants.
raise ValueError(f"Transition key {key.name} not found in participants.")
for v in value:
if v not in participants:
# Make sure all values are in the list of participants.
raise ValueError(f"Transition value {v.name} not found in participants.")
self._transitions = transitions
self._on_message_received = on_message_received
@message_handler()
@@ -69,13 +87,13 @@ class GroupChatManager(TypeRoutedAgent):
# Save the message to chat memory.
await self._memory.add_message(message)
# Get the last speaker.
last_speaker_name = message.source
last_speaker_index = next((i for i, p in enumerate(self._participants) if p.name == last_speaker_name), None)
# Select speaker.
if self._client is None:
# If no model client is provided, select the next speaker from the list of participants.
last_speaker_name = message.source
last_speaker_index = next(
(i for i, p in enumerate(self._participants) if p.name == last_speaker_name), None
)
if last_speaker_index is None:
# If the last speaker is not found, select the first speaker in the list.
next_speaker_index = 0
@@ -83,8 +101,16 @@ class GroupChatManager(TypeRoutedAgent):
next_speaker_index = (last_speaker_index + 1) % len(self._participants)
speaker = self._participants[next_speaker_index]
else:
# If a model client is provided, select the speaker based on the model output.
speaker = await select_speaker(self._memory, self._client, self._participants)
# If a model client is provided, select the speaker based on the transitions and the model.
candidates = self._participants
if last_speaker_index is not None:
last_speaker = self._participants[last_speaker_index]
if self._transitions.get(last_speaker) is not None:
candidates = self._transitions[last_speaker]
if len(candidates) == 1:
speaker = candidates[0]
else:
speaker = await select_speaker(self._memory, self._client, candidates)
# Send the message to the selected speaker to ask it to publish a response.
await self._send_message(PublishNow(), speaker)
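The candidate-selection rule in the handler above distills to the following standalone restatement (not library code):

def choose_candidates(last_speaker, participants, transitions):
    # Transitions narrow the choice when the last speaker has an entry;
    # otherwise every participant is eligible.
    if last_speaker is not None and last_speaker in transitions:
        return transitions[last_speaker]
    return participants

A single remaining candidate is sent PublishNow directly; only a multi-candidate set triggers the `select_speaker` model call.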