Improve the assistant and chess examples to make them more robust and to improve their presentation. (#58)

* Improve the assistant and chess examples to make them more robust and to improve their presentation.

* type dep

* format
This commit is contained in:
Eric Zhu
2024-06-08 01:27:27 -07:00
committed by GitHub
parent 21b730e7c6
commit 37cc6bc12a
5 changed files with 384 additions and 168 deletions

239
examples/assistant.py Normal file
View File

@@ -0,0 +1,239 @@
"""This is an example of a chat with an OpenAIAssistantAgent.
You must have OPENAI_API_KEY set up in your environment to
run this example.
"""
import os
import re
from typing import Any, List
import aiofiles
import openai
from agnext.application import SingleThreadedAgentRuntime
from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.patterns.group_chat import GroupChatOutput
from agnext.chat.patterns.two_agent_chat import TwoAgentChat
from agnext.chat.types import RespondNow, TextMessage
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import AgentRuntime, CancellationToken
from openai import AsyncAssistantEventHandler
from openai.types.beta.thread import ToolResources
from openai.types.beta.threads import Message, Text, TextDelta
from openai.types.beta.threads.runs import RunStep, RunStepDelta
from typing_extensions import override
class TwoAgentChatOutput(GroupChatOutput):  # type: ignore
    """A no-op output sink: this example streams results to the console instead."""

    def on_message_received(self, message: Any) -> None:
        return None

    def get_output(self) -> Any:
        return None

    def reset(self) -> None:
        return None
sep = "-" * 50
class UserProxyAgent(BaseChatAgent, TypeRoutedAgent):  # type: ignore
    """A console-backed human user agent.

    Prompts the user for input when asked to respond, and supports an
    '[upload code_interpreter | file_search filename]' command that uploads a
    local file either to the thread's code interpreter resources or to the
    assistant's vector store for file search.
    """

    def __init__(
        self,
        name: str,
        runtime: AgentRuntime,
        client: openai.AsyncClient,
        assistant_id: str,
        thread_id: str,
        vector_store_id: str,
    ) -> None:  # type: ignore
        super().__init__(
            name=name,
            description="A human user",
            runtime=runtime,
        )
        self._client = client
        self._assistant_id = assistant_id
        self._thread_id = thread_id
        self._vector_store_id = vector_store_id

    @message_handler()  # type: ignore
    async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:  # type: ignore
        # The assistant's streaming event handler already prints its output,
        # so incoming text messages are intentionally not echoed here.
        # TODO: render image if message has image.
        pass

    @message_handler()  # type: ignore
    async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:  # type: ignore
        # Loop until the user enters a regular chat message; upload commands
        # are handled in place and then re-prompt.
        while True:
            user_input = input(f"\n{sep}\nYou: ")
            # Parse upload file command '[upload code_interpreter | file_search filename]'.
            match = re.search(r"\[upload\s+(code_interpreter|file_search)\s+(.+)\]", user_input)
            if match:
                # Purpose of the file.
                purpose = match.group(1)
                # Extract file path.
                file_path = match.group(2)
                if not os.path.exists(file_path):
                    print(f"File not found: {file_path}")
                    continue
                # Filename.
                file_name = os.path.basename(file_path)
                # Read file content.
                async with aiofiles.open(file_path, "rb") as f:
                    file_content = await f.read()
                if purpose == "code_interpreter":
                    # Upload file.
                    file = await self._client.files.create(file=(file_name, file_content), purpose="assistants")
                    # Get existing file ids from tool resources.
                    thread = await self._client.beta.threads.retrieve(thread_id=self._thread_id)
                    tool_resources: ToolResources = thread.tool_resources if thread.tool_resources else ToolResources()
                    # BUGFIX: previously the new file id was dropped whenever the
                    # thread already had files (only the old list was kept), and a
                    # missing code_interpreter resource tripped an assert. Append
                    # the new id to whatever is already attached.
                    existing_ids: List[str] = []
                    if tool_resources.code_interpreter is not None and tool_resources.code_interpreter.file_ids:
                        existing_ids = list(tool_resources.code_interpreter.file_ids)
                    file_ids = existing_ids + [file.id]
                    # Update thread with new file.
                    await self._client.beta.threads.update(
                        thread_id=self._thread_id,
                        tool_resources={"code_interpreter": {"file_ids": file_ids}},
                    )
                elif purpose == "file_search":
                    # Upload file to vector store.
                    file_batch = await self._client.beta.vector_stores.file_batches.upload_and_poll(
                        vector_store_id=self._vector_store_id,
                        files=[(file_name, file_content)],
                    )
                    assert file_batch.status == "completed"
                print(f"Uploaded file: {file_name}")
                continue
            elif user_input.startswith("[upload"):
                print("Invalid upload command. Please use '[upload code_interpreter | file_search filename]'.")
                continue
            else:
                # Send user input to assistant.
                return TextMessage(content=user_input, source=self.name)
class EventHandler(AsyncAssistantEventHandler):
    """Streams assistant run events to the console.

    Prints assistant text as it arrives, wraps code-interpreter input in a
    ```python fence, and prints file-search citations after each message.
    """

    @override
    async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        # Stream assistant text as it arrives.
        print(delta.value, end="", flush=True)

    @override
    async def on_run_step_created(self, run_step: RunStep) -> None:
        # Open a code fence when a code-interpreter tool call starts.
        details = run_step.step_details
        if details.type == "tool_calls":
            for tool in details.tool_calls:
                if tool.type == "code_interpreter":
                    print("\nGenerating code to interpret:\n\n```python")

    @override
    async def on_run_step_done(self, run_step: RunStep) -> None:
        # Close the code fence once the code-interpreter step completes.
        details = run_step.step_details
        if details.type == "tool_calls":
            for tool in details.tool_calls:
                if tool.type == "code_interpreter":
                    print("\n```\nExecuting code...")

    @override
    async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
        # Stream generated code-interpreter input as it is produced.
        details = delta.step_details
        if details is not None and details.type == "tool_calls":
            for tool in details.tool_calls or []:
                if tool.type == "code_interpreter" and tool.code_interpreter and tool.code_interpreter.input:
                    print(tool.code_interpreter.input, end="", flush=True)

    @override
    async def on_message_created(self, message: Message) -> None:
        print(f"{sep}\nAssistant:\n")

    @override
    async def on_message_done(self, message: Message) -> None:
        # print a citation to the file searched
        if not message.content:
            return
        content = message.content[0]
        if not content.type == "text":
            return
        text_content = content.text
        annotations = text_content.annotations
        citations: List[str] = []
        # FIX: previously a new AsyncClient was constructed for every single
        # annotation; create at most one, lazily, for all citations.
        client = None
        for index, annotation in enumerate(annotations):
            # Replace the inline citation marker with a numbered reference.
            text_content.value = text_content.value.replace(annotation.text, f"[{index}]")
            if file_citation := getattr(annotation, "file_citation", None):
                if client is None:
                    client = openai.AsyncClient()
                cited_file = await client.files.retrieve(file_citation.file_id)
                citations.append(f"[{index}] {cited_file.filename}")
        if citations:
            print("\n".join(citations))
def assistant_chat(runtime: AgentRuntime) -> TwoAgentChat:  # type: ignore
    """Provision OpenAI resources and build the assistant/user two-agent chat.

    Creates a remote assistant (with code interpreter and file search tools),
    a vector store backing file search, and a conversation thread, then wraps
    them in agents registered on the given runtime.
    """
    remote_assistant = openai.beta.assistants.create(
        model="gpt-4-turbo",
        description="An AI assistant that helps with everyday tasks.",
        instructions="Help the user with their task.",
        tools=[{"type": "code_interpreter"}, {"type": "file_search"}],
    )
    # The vector store backs the file_search tool for this thread.
    store = openai.beta.vector_stores.create()
    conversation = openai.beta.threads.create(
        tool_resources={"file_search": {"vector_store_ids": [store.id]}},
    )
    assistant_agent = OpenAIAssistantAgent(
        name="Assistant",
        description="An AI assistant that helps with everyday tasks.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=remote_assistant.id,
        thread_id=conversation.id,
        # A fresh event handler per run streams output to the console.
        assistant_event_handler_factory=lambda: EventHandler(),
    )
    user_agent = UserProxyAgent(
        name="User",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=remote_assistant.id,
        thread_id=conversation.id,
        vector_store_id=store.id,
    )
    return TwoAgentChat(
        name="AssistantChat",
        description="A chat with an AI assistant",
        runtime=runtime,
        first_speaker=assistant_agent,
        second_speaker=user_agent,
        num_rounds=100,
        output=TwoAgentChatOutput(),
    )
async def main() -> None:
    """Run the interactive assistant chat until the chat future completes."""
    # NOTE: this multi-line literal is user-facing help text printed verbatim;
    # keep its lines unindented.
    usage = """Chat with an AI assistant backed by OpenAI Assistant API.
You can upload files to the assistant using the command:
[upload code_interpreter | file_search filename]
where 'code_interpreter' or 'file_search' is the purpose of the file and
'filename' is the path to the file. For example:
[upload code_interpreter data.csv]
This will upload data.csv to the assistant for use with the code interpreter tool.
"""
    runtime = SingleThreadedAgentRuntime()
    chat = assistant_chat(runtime)
    print(usage)
    # Kick off the conversation; the chat pattern drives the turn-taking.
    future = runtime.send_message(
        TextMessage(content="Hello.", source="User"),
        chat,
    )
    # Drive the single-threaded runtime until the chat produces a result.
    while not future.done():
        await runtime.process_next()
if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

View File

@@ -7,21 +7,19 @@ You must have OPENAI_API_KEY set up in your environment to run this example.
import argparse
import asyncio
import logging
from typing import Annotated
from typing import Annotated, Literal
from agnext.application import SingleThreadedAgentRuntime
from agnext.chat.agents.chat_completion_agent import ChatCompletionAgent
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
from agnext.chat.patterns.two_agent_chat import TwoAgentChat
from agnext.chat.types import TextMessage
from agnext.components.models import OpenAI, SystemMessage
from agnext.components.tools import FunctionTool
from agnext.core import AgentRuntime
from chess import SQUARE_NAMES, Board, Move
from chess import BLACK, SQUARE_NAMES, WHITE, Board, Move
from chess import piece_name as get_piece_name
logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)
class ChessGameOutput(GroupChatOutput): # type: ignore
def on_message_received(self, message: TextMessage) -> None: # type: ignore
@@ -34,34 +32,136 @@ class ChessGameOutput(GroupChatOutput): # type: ignore
pass
def validate_turn(board: Board, player: Literal["white", "black"]) -> None:
    """Validate that it is the player's turn to move.

    Raises:
        ValueError: if the other side moved last, or — at game start —
            if the caller is not white.
    """
    if not board.move_stack:
        # No moves yet: white always opens.
        if player != "white":
            raise ValueError("It is not your turn to move. Wait for white to move first.")
        return
    # The piece on the last move's destination belongs to whoever moved last;
    # that side cannot move twice in a row.
    last_mover = board.color_at(board.peek().to_square)
    if player == "white" and last_mover == WHITE:
        raise ValueError("It is not your turn to move. Wait for black to move.")
    if player == "black" and last_mover == BLACK:
        raise ValueError("It is not your turn to move. Wait for white to move.")
def get_legal_moves(
    board: Board, player: Literal["white", "black"]
) -> Annotated[str, "A list of legal moves in UCI format."]:
    """Get legal moves for the given player."""
    validate_turn(board, player)
    if player not in ("white", "black"):
        raise ValueError("Invalid player, must be either 'black' or 'white'.")
    side = WHITE if player == "white" else BLACK
    # board.legal_moves only yields moves for the side to move; the color
    # filter is kept as a belt-and-braces guard.
    uci_moves = [m.uci() for m in board.legal_moves if board.color_at(m.from_square) == side]
    if not uci_moves:
        return "No legal moves. The game is over."
    return "Possible moves are: " + ", ".join(uci_moves)
def get_board(board: Board) -> str:
    """Render the current board state as text."""
    return f"{board}"
def make_move(
    board: Board,
    player: Literal["white", "black"],
    thinking: Annotated[str, "Thinking for the move."],
    move: Annotated[str, "A move in UCI format."],
) -> Annotated[str, "Result of the move."]:
    """Make a move on the board.

    Validates that it is the player's turn and that the move is legal before
    pushing it, prints the move and resulting board, and returns a natural
    language description of the move.

    Raises:
        ValueError: if it is not the player's turn, the UCI string is
            malformed, or the move is illegal in the current position.
    """
    validate_turn(board, player)
    new_move = Move.from_uci(move)
    # Board.push() does not check legality; an illegal push would silently
    # corrupt the game state, so reject it with an actionable error instead.
    if new_move not in board.legal_moves:
        raise ValueError(f"Illegal move: {move}. Use get_legal_moves() to find legal moves.")
    board.push(new_move)
    # Print the move.
    print("-" * 50)
    print("Player:", player)
    print("Move:", new_move.uci())
    print("Thinking:", thinking)
    print("Board:")
    print(board.unicode(borders=True))
    # Get the piece name.
    piece = board.piece_at(new_move.to_square)
    assert piece is not None
    piece_symbol = piece.unicode_symbol()
    piece_name = get_piece_name(piece.piece_type)
    if piece_symbol.isupper():
        # Uppercase unicode symbol means a white piece; capitalize its name.
        piece_name = piece_name.capitalize()
    return f"Moved {piece_name} ({piece_symbol}) from {SQUARE_NAMES[new_move.from_square]} to {SQUARE_NAMES[new_move.to_square]}."
def chess_game(runtime: AgentRuntime) -> GroupChat: # type: ignore
"""Create agents for a chess game and return the group chat."""
# Create the board.
board = Board()
# Create shared tools.
def get_legal_moves() -> Annotated[str, "A list of legal moves in UCI format."]:
return "Possible moves are: " + ", ".join([str(move) for move in board.legal_moves])
# Create tools for each player.
# @functools.wraps(get_legal_moves)
def get_legal_moves_black() -> str:
return get_legal_moves(board, "black")
get_legal_moves_tool = FunctionTool(get_legal_moves, description="Get legal moves.")
# @functools.wraps(get_legal_moves)
def get_legal_moves_white() -> str:
return get_legal_moves(board, "white")
def make_move(input: Annotated[str, "A move in UCI format."]) -> Annotated[str, "Result of the move."]:
move = Move.from_uci(input)
board.push(move)
print(board.unicode(borders=True))
# Get the piece name.
piece = board.piece_at(move.to_square)
assert piece is not None
piece_symbol = piece.unicode_symbol()
piece_name = get_piece_name(piece.piece_type)
if piece_symbol.isupper():
piece_name = piece_name.capitalize()
return f"Moved {piece_name} ({piece_symbol}) from {SQUARE_NAMES[move.from_square]} to {SQUARE_NAMES[move.to_square]}."
# @functools.wraps(make_move)
def make_move_black(
thinking: Annotated[str, "Thinking for the move"],
move: Annotated[str, "A move in UCI format"],
) -> str:
return make_move(board, "black", thinking, move)
make_move_tool = FunctionTool(make_move, description="Call this tool to make a move.")
# @functools.wraps(make_move)
def make_move_white(
thinking: Annotated[str, "Thinking for the move"],
move: Annotated[str, "A move in UCI format"],
) -> str:
return make_move(board, "white", thinking, move)
tools = [get_legal_moves_tool, make_move_tool]
def get_board_text() -> Annotated[str, "The current board state"]:
return get_board(board)
black_tools = [
FunctionTool(
get_legal_moves_black,
name="get_legal_moves",
description="Get legal moves.",
),
FunctionTool(
make_move_black,
name="make_move",
description="Make a move.",
),
FunctionTool(
get_board_text,
name="get_board",
description="Get the current board state.",
),
]
white_tools = [
FunctionTool(
get_legal_moves_white,
name="get_legal_moves",
description="Get legal moves.",
),
FunctionTool(
make_move_white,
name="make_move",
description="Make a move.",
),
FunctionTool(
get_board_text,
name="get_board",
description="Get the current board state.",
),
]
black = ChatCompletionAgent(
name="PlayerBlack",
@@ -70,12 +170,13 @@ def chess_game(runtime: AgentRuntime) -> GroupChat: # type: ignore
system_messages=[
SystemMessage(
content="You are a chess player and you play as black. "
"First call get_legal_moves() first, to get list of legal moves. "
"Then call make_move(move) to make a move."
"Use get_legal_moves() to get list of legal moves. "
"Use get_board() to get the current board state. "
"Think about your strategy and call make_move(thinking, move) to make a move."
),
],
model_client=OpenAI(model="gpt-4-turbo"),
tools=tools,
tools=black_tools,
)
white = ChatCompletionAgent(
name="PlayerWhite",
@@ -84,19 +185,21 @@ def chess_game(runtime: AgentRuntime) -> GroupChat: # type: ignore
system_messages=[
SystemMessage(
content="You are a chess player and you play as white. "
"First call get_legal_moves() first, to get list of legal moves. "
"Then call make_move(move) to make a move."
"Use get_legal_moves() to get list of legal moves. "
"Use get_board() to get the current board state. "
"Think about your strategy and call make_move(thinking, move) to make a move."
),
],
model_client=OpenAI(model="gpt-4-turbo"),
tools=tools,
tools=white_tools,
)
game_chat = GroupChat(
game_chat = TwoAgentChat(
name="ChessGame",
description="A chess game between two agents.",
runtime=runtime,
agents=[white, black],
num_rounds=10,
first_speaker=white,
second_speaker=black,
num_rounds=100,
output=ChessGameOutput(),
)
return game_chat
@@ -117,5 +220,10 @@ if __name__ == "__main__":
default="Please make a move.",
help="The initial message to send to the agent playing white.",
)
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
args = parser.parse_args()
if args.verbose:
logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)
asyncio.run(main(args.initial_message))

View File

@@ -1,133 +0,0 @@
"""This is an example of a chat with an OAI assistant agent.
You must have OPENAI_API_KEY set up in your environment to
run this example.
"""
from typing import Any
import openai
from agnext.application import SingleThreadedAgentRuntime
from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.patterns.group_chat import GroupChatOutput
from agnext.chat.patterns.two_agent_chat import TwoAgentChat
from agnext.chat.types import RespondNow, TextMessage
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import AgentRuntime, CancellationToken
from openai import AsyncAssistantEventHandler
from openai.types.beta import AssistantStreamEvent
from openai.types.beta.threads import Text, TextDelta
from openai.types.beta.threads.runs import RunStep, RunStepDelta
from typing_extensions import override
class TwoAgentChatOutput(GroupChatOutput):  # type: ignore
    """A do-nothing chat output; this example reads results from the console."""

    def on_message_received(self, message: Any) -> None:
        return None

    def get_output(self) -> Any:
        return None

    def reset(self) -> None:
        return None
sep = "-" * 50
class UserProxyAgent(BaseChatAgent, TypeRoutedAgent):  # type: ignore
    """Relays console input from a human user into the chat."""

    def __init__(self, name: str, runtime: AgentRuntime) -> None:  # type: ignore
        super().__init__(name=name, description="A human user", runtime=runtime)

    @message_handler()  # type: ignore
    async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:  # type: ignore
        # The assistant's event handler streams its own output, so incoming
        # text is deliberately ignored here.
        # TODO: render image if message has image.
        pass

    @message_handler()  # type: ignore
    async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:  # type: ignore
        # Block on console input and forward it as this user's reply.
        # TODO: add parsing for special commands e.g., upload files, exit, etc.
        reply = input(f"\n{sep}\nYou: ")
        return TextMessage(content=reply, source=self.name)
class EventHandler(AsyncAssistantEventHandler):
    """Prints streaming events from an OpenAI Assistant run to the console."""
    @override
    async def on_event(self, event: AssistantStreamEvent) -> None:
        # Announce tool-call steps and new assistant messages.
        if event.event == "thread.run.step.created":
            details = event.data.step_details
            if details.type == "tool_calls":
                print("\nGenerating code to interpret:\n\n```python")
        elif event.event == "thread.message.created":
            print(f"{sep}\nAssistant:\n")
    @override
    async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        # Stream assistant text to the console as it arrives.
        print(delta.value, end="", flush=True)
    @override
    async def on_run_step_done(self, run_step: RunStep) -> None:
        # Close the code fence once the code-interpreter step finishes.
        details = run_step.step_details
        if details.type == "tool_calls":
            for tool in details.tool_calls:
                if tool.type == "code_interpreter":
                    print("\n```\nExecuting code...")
    @override
    async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
        # Stream generated code-interpreter input as it is produced.
        details = delta.step_details
        if details is not None and details.type == "tool_calls":
            for tool in details.tool_calls or []:
                if tool.type == "code_interpreter" and tool.code_interpreter and tool.code_interpreter.input:
                    print(tool.code_interpreter.input, end="", flush=True)
def assistant_chat(runtime: AgentRuntime) -> TwoAgentChat:  # type: ignore
    """Create the user and assistant agents and wire them into a two-agent chat."""
    user = UserProxyAgent(name="User", runtime=runtime)
    # Create the remote OpenAI assistant with the code interpreter tool enabled.
    oai_assistant = openai.beta.assistants.create(
        model="gpt-4-turbo",
        description="An AI assistant that helps with everyday tasks.",
        instructions="Help the user with their task.",
        tools=[{"type": "code_interpreter"}],
    )
    # A fresh conversation thread for this chat session.
    thread = openai.beta.threads.create()
    assistant = OpenAIAssistantAgent(
        name="Assistant",
        description="An AI assistant that helps with everyday tasks.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=oai_assistant.id,
        thread_id=thread.id,
        # Each run gets a fresh handler that streams output to the console.
        assistant_event_handler_factory=lambda: EventHandler(),
    )
    return TwoAgentChat(
        name="AssistantChat",
        description="A chat with an AI assistant",
        runtime=runtime,
        initial_sender=user,
        initial_recipient=assistant,
        num_rounds=100,
        output=TwoAgentChatOutput(),
    )
async def main() -> None:
    """Entry point: run the chat until the conversation future completes."""
    runtime = SingleThreadedAgentRuntime()
    chat = assistant_chat(runtime)
    # Kick off the conversation; the chat pattern drives the turn-taking.
    future = runtime.send_message(
        TextMessage(content="Hello.", source="User"),
        chat,
    )
    # Drive the single-threaded runtime until the chat produces a result.
    while not future.done():
        await runtime.process_next()
if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

View File

@@ -34,6 +34,8 @@ dev = [
# Dependencies for the examples.
"chess",
"tavily-python",
"aiofiles",
"types-aiofiles",
]
docs = ["sphinx", "furo", "sphinxcontrib-apidoc", "myst-parser"]

View File

@@ -12,8 +12,8 @@ class TwoAgentChat(GroupChat):
name: str,
description: str,
runtime: AgentRuntime,
initial_sender: BaseChatAgent,
initial_recipient: BaseChatAgent,
first_speaker: BaseChatAgent,
second_speaker: BaseChatAgent,
num_rounds: int,
output: GroupChatOutput,
) -> None:
@@ -21,7 +21,7 @@ class TwoAgentChat(GroupChat):
name,
description,
runtime,
[initial_recipient, initial_sender],
[first_speaker, second_speaker],
num_rounds,
output,
)