Orchestrator Chat and OAI Assistant update (#31)

This commit is contained in:
Eric Zhu
2024-05-28 18:11:55 -07:00
committed by GitHub
parent ecbc3b7806
commit e3a2f79e65
4 changed files with 138 additions and 153 deletions

View File

@@ -3,13 +3,12 @@ import asyncio
import logging
import openai
from agnext.agent_components.model_client import OpenAI
from agnext.application_components import (
SingleThreadedAgentRuntime,
)
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
from agnext.chat.patterns.orchestrator import Orchestrator
from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
from agnext.chat.types import TextMessage
from agnext.core._agent import Agent
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
@@ -38,20 +37,28 @@ class ConcatOutput(GroupChatOutput):
class LoggingHandler(DefaultInterventionHandler):
send_color = "\033[31m"
response_color = "\033[34m"
reset_color = "\033[0m"
@override
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]:
if sender is None:
print(f"Sending message to {recipient.name}: {message}")
print(f"{self.send_color}Sending message to {recipient.name}:{self.reset_color} {message}")
else:
print(f"Sending message from {sender.name} to {recipient.name}: {message}")
print(
f"{self.send_color}Sending message from {sender.name} to {recipient.name}:{self.reset_color} {message}"
)
return message
@override
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
if recipient is None:
print(f"Received response from {sender.name}: {message}")
print(f"{self.response_color}Received response from {sender.name}:{self.reset_color} {message}")
else:
print(f"Received response from {sender.name} to {recipient.name}: {message}")
print(
f"{self.response_color}Received response from {sender.name} to {recipient.name}:{self.reset_color} {message}"
)
return message
@@ -131,18 +138,46 @@ async def orchestrator(message: str) -> None:
thread_id=product_manager_oai_thread.id,
)
chat = Orchestrator(
"Manager",
"A software development team manager.",
runtime,
[developer, product_manager],
model_client=OpenAI(model="gpt-3.5-turbo"),
planner_oai_assistant = openai.beta.assistants.create(
model="gpt-4-turbo",
name="Planner",
instructions="You are a planner of complex tasks.",
)
planner_oai_thread = openai.beta.threads.create()
planner = OpenAIAssistantAgent(
name="Planner",
description="A planner that organizes and schedules tasks.",
runtime=runtime,
client=openai.AsyncClient(),
assistant_id=planner_oai_assistant.id,
thread_id=planner_oai_thread.id,
)
response = runtime.send_message(
TextMessage(content=message, source="customer"),
chat,
orchestrator_oai_assistant = openai.beta.assistants.create(
model="gpt-4-turbo",
name="Orchestrator",
instructions="You are an orchestrator that coordinates the team to complete a complex task.",
)
orchestrator_oai_thread = openai.beta.threads.create()
orchestrator = OpenAIAssistantAgent(
name="Orchestrator",
description="An orchestrator that coordinates the team.",
runtime=runtime,
client=openai.AsyncClient(),
assistant_id=orchestrator_oai_assistant.id,
thread_id=orchestrator_oai_thread.id,
)
chat = OrchestratorChat(
"Orchestrator Chat",
"A software development team.",
runtime,
orchestrator=orchestrator,
planner=planner,
specialists=[developer, product_manager],
)
response = runtime.send_message(TextMessage(content=message, source="Customer"), chat)
while not response.done():
await runtime.process_next()

View File

@@ -1,4 +1,4 @@
from typing import Callable, Dict
from typing import Callable, Dict, List
import openai
@@ -23,8 +23,6 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
self._client = client
self._assistant_id = assistant_id
self._thread_id = thread_id
# TODO: investigate why this is 1, as setting this to 0 causes the earlest message in the window to be ignored.
self._current_session_window_length = 1
self._tools = tools or {}
@message_handler(TextMessage)
@@ -36,32 +34,40 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
role="user",
metadata={"sender": message.source},
)
self._current_session_window_length += 1
@message_handler(Reset)
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
# Reset the current session window.
self._current_session_window_length = 1
# Get all messages in this thread.
all_msgs: List[str] = []
while True:
if not all_msgs:
msgs = await self._client.beta.threads.messages.list(self._thread_id)
else:
msgs = await self._client.beta.threads.messages.list(self._thread_id, after=all_msgs[-1])
for msg in msgs.data:
all_msgs.append(msg.id)
if not msgs.has_next_page():
break
# Delete all the messages.
for msg_id in all_msgs:
status = await self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
assert status.deleted is True
@message_handler(RespondNow)
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
# Handle response format.
# Create a run and wait until it finishes.
run = await self._client.beta.threads.runs.create_and_poll(
thread_id=self._thread_id,
assistant_id=self._assistant_id,
truncation_strategy={
"type": "last_messages",
"last_messages": self._current_session_window_length,
},
response_format=message.response_format,
)
if run.status != "completed":
# TODO: handle other statuses.
raise ValueError(f"Run did not complete successfully: {run}")
# Increment the current session window length.
self._current_session_window_length += 1
# Get the last message from the run.
response = await self._client.beta.threads.messages.list(self._thread_id, run_id=run.id, order="desc", limit=1)
last_message_content = response.data[0].content

View File

@@ -1,40 +1,39 @@
import json
from typing import Any, List, Sequence, Tuple
from typing import Any, Sequence, Tuple
from ...agent_components.model_client import ModelClient
from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from ...agent_components.types import AssistantMessage, LLMMessage, UserMessage
from ...core import AgentRuntime, CancellationToken
from ..agents.base import BaseChatAgent
from ..types import RespondNow, TextMessage
from ..types import Reset, RespondNow, TextMessage
class Orchestrator(BaseChatAgent, TypeRoutedAgent):
class OrchestratorChat(BaseChatAgent, TypeRoutedAgent):
def __init__(
self,
name: str,
description: str,
runtime: AgentRuntime,
agents: Sequence[BaseChatAgent],
model_client: ModelClient,
orchestrator: BaseChatAgent,
planner: BaseChatAgent,
specialists: Sequence[BaseChatAgent],
max_turns: int = 30,
max_stalled_turns_before_retry: int = 2,
max_retry_attempts: int = 1,
) -> None:
super().__init__(name, description, runtime)
self._agents = agents
self._model_client = model_client
self._orchestrator = orchestrator
self._planner = planner
self._specialists = specialists
self._max_turns = max_turns
self._max_stalled_turns_before_retry = max_stalled_turns_before_retry
self._max_retry_attempts_before_educated_guess = max_retry_attempts
self._history: List[TextMessage] = []
@message_handler(TextMessage)
async def on_text_message(
self,
message: TextMessage,
cancellation_token: CancellationToken,
) -> TextMessage | None:
) -> TextMessage:
# A task is received.
task = message.content
@@ -44,8 +43,11 @@ class Orchestrator(BaseChatAgent, TypeRoutedAgent):
# Main loop.
total_turns = 0
retry_attempts = 0
ledgers: List[List[LLMMessage]] = []
while total_turns < self._max_turns:
# Reset all agents.
for agent in [*self._specialists, self._orchestrator]:
self._send_message(Reset(), agent)
# Create the task specs.
task_specs = f"""
We are working to address the following user request:
@@ -64,30 +66,15 @@ Some additional points to consider:
{plan}
""".strip()
# Send the task specs to the team and signal a reset.
for agent in self._agents:
self._send_message(
TextMessage(
content=task_specs,
source=self.name,
),
agent,
)
# Create the ledger.
ledger: List[LLMMessage] = [
AssistantMessage(
content=task_specs,
source=self.name,
)
]
ledgers.append(ledger)
# Send the task specs to the orchestrator and specialists.
for agent in [*self._specialists, self._orchestrator]:
self._send_message(TextMessage(content=task_specs, source=self.name), agent)
# Inner loop.
stalled_turns = 0
while total_turns < self._max_turns:
# Reflect on the task.
data = await self._reflect_on_task(task, team, names, ledger, message.source)
data = await self._reflect_on_task(task, team, names, message.source)
# Check if the request is satisfied.
if data["is_request_satisfied"]["answer"]:
@@ -107,7 +94,7 @@ Some additional points to consider:
# In a retry, we need to rewrite the facts and the plan.
# Rewrite the facts.
facts = await self._rewrite_facts(facts, ledger, message.source)
facts = await self._rewrite_facts(facts, message.source)
# Increment the retry attempts.
retry_attempts += 1
@@ -115,7 +102,7 @@ Some additional points to consider:
# Check if we should just guess.
if retry_attempts > self._max_retry_attempts_before_educated_guess:
# Make an educated guess.
educated_guess = await self._educated_guess(facts, ledger, message.source)
educated_guess = await self._educated_guess(facts, message.source)
if educated_guess["has_educated_guesses"]["answer"]:
return TextMessage(
content=f"The task is addressed with an educated guess. {educated_guess['has_educated_guesses']['reason']}",
@@ -123,7 +110,7 @@ Some additional points to consider:
)
# Come up with a new plan.
plan = await self._rewrite_plan(team, ledger, message.source)
plan = await self._rewrite_plan(team, message.source)
# Exit the inner loop.
break
@@ -134,7 +121,7 @@ Some additional points to consider:
subtask = ""
# Update agents.
for agent in [agent for agent in self._agents]:
for agent in [*self._specialists, self._orchestrator]:
_ = await self._send_message(
TextMessage(content=subtask, source=self.name),
agent,
@@ -142,41 +129,20 @@ Some additional points to consider:
# Find the speaker.
try:
speaker = next(agent for agent in self._agents if agent.name == data["next_speaker"]["answer"])
speaker = next(agent for agent in self._specialists if agent.name == data["next_speaker"]["answer"])
except StopIteration as e:
raise ValueError(f"Invalid next speaker: {data['next_speaker']['answer']}") from e
# As speaker to speak.
# Ask speaker to speak.
speaker_response = await self._send_message(RespondNow(), speaker)
assert speaker_response is not None
# Update the ledger.
ledger.append(
AssistantMessage(
content=subtask,
source=self.name,
)
)
# Update all other agents with the speaker's response.
for agent in [agent for agent in self._agents if agent != speaker]:
_ = await self._send_message(
TextMessage(
content=speaker_response.content,
source=speaker_response.source,
),
agent,
for agent in [agent for agent in self._specialists if agent != speaker] + [self._orchestrator]:
self._send_message(
TextMessage(content=speaker_response.content, source=speaker_response.source), agent
)
# Update the ledger.
ledger.append(
UserMessage(
content=speaker_response.content,
source=speaker_response.source,
)
)
# Increment the total turns.
total_turns += 1
@@ -186,9 +152,12 @@ Some additional points to consider:
)
async def _prepare_task(self, task: str, sender: str) -> Tuple[str, str, str, str]:
# Reset planner.
self._send_message(Reset(), self._planner)
# A reusable description of the team.
team = "\n".join([agent.name + ": " + agent.description for agent in self._agents])
names = ", ".join([agent.name for agent in self._agents])
team = "\n".join([agent.name + ": " + agent.description for agent in self._specialists])
names = ", ".join([agent.name for agent in self._specialists])
# A place to store relevant facts.
facts = ""
@@ -218,19 +187,10 @@ When answering this survey, keep in mind that "facts" will typically be specific
4. EDUCATED GUESSES
""".strip()
starter_messages: List[LLMMessage] = [
UserMessage(
content=closed_book_prompt,
source=sender,
)
]
facts_response = await self._model_client.create(messages=starter_messages)
starter_messages.append(
AssistantMessage(
content=facts_response.content,
source=self.name,
)
)
# Ask the planner to obtain prior knowledge about facts.
self._send_message(TextMessage(content=closed_book_prompt, source=sender), self._planner)
facts_response = await self._send_message(RespondNow(), self._planner)
facts = str(facts_response.content)
# Make an initial plan
@@ -239,19 +199,10 @@ When answering this survey, keep in mind that "facts" will typically be specific
{team}
Based on the team composition, and known and unknown facts, please devise a short bullet-point plan for addressing the original request. Remember, there is no requirement to involve all team members -- a team member's particular expertise may not be needed for this task.""".strip()
starter_messages.append(
UserMessage(
content=plan_prompt,
source=sender,
)
)
plan_response = await self._model_client.create(messages=starter_messages)
starter_messages.append(
AssistantMessage(
content=plan_response.content,
source=self.name,
)
)
# Send second messag eto the planner.
self._send_message(TextMessage(content=plan_prompt, source=sender), self._planner)
plan_response = await self._send_message(RespondNow(), self._planner)
plan = str(plan_response.content)
return team, names, facts, plan
@@ -261,7 +212,6 @@ Based on the team composition, and known and unknown facts, please devise a shor
task: str,
team: str,
names: str,
ledger: List[LLMMessage],
sender: str,
) -> Any:
step_prompt = f"""
@@ -301,37 +251,28 @@ Please output an answer in pure JSON format according to the following schema. T
}}
}}
""".strip()
step_response = await self._model_client.create(
messages=ledger + [UserMessage(content=step_prompt, source=sender)],
extra_create_args={"response_format": {"type": "json_object"}},
# Send a message to the orchestrator.
self._send_message(TextMessage(content=step_prompt, source=sender), self._orchestrator)
# Request a response.
step_response = await self._send_message(
RespondNow(response_format={"type": "json_object"}), self._orchestrator
)
step_response_json = str(step_response.content)
# TODO: handle invalid JSON.
# TODO: use typed dictionary.
return json.loads(step_response_json)
return json.loads(step_response.content)
async def _rewrite_facts(self, facts: str, ledger: List[LLMMessage], sender: str) -> str:
async def _rewrite_facts(self, facts: str, sender: str) -> str:
new_facts_prompt = f"""It's clear we aren't making as much progress as we would like, but we may have learned something new. Please rewrite the following fact sheet, updating it to include anything new we have learned. This is also a good time to update educated guesses (please add or update at least one educated guess or hunch, and explain your reasoning).
{facts}
""".strip()
ledger.append(
UserMessage(
content=new_facts_prompt,
source=sender,
)
)
new_facts_response = await self._model_client.create(messages=ledger)
facts = str(new_facts_response.content)
ledger.append(
AssistantMessage(
content=facts,
source=self.name,
)
)
return facts
# Send a message to the orchestrator.
self._send_message(TextMessage(content=new_facts_prompt, source=sender), self._orchestrator)
# Request a response.
new_facts_response = await self._send_message(RespondNow(), self._orchestrator)
return str(new_facts_response.content)
async def _educated_guess(self, facts: str, ledger: List[LLMMessage], sender: str) -> Any:
async def _educated_guess(self, facts: str, sender: str) -> Any:
# Make an educated guess.
educated_guess_promt = f"""Given the following information
@@ -348,25 +289,24 @@ Please output an answer in pure JSON format according to the following schema. T
}}
}}
""".strip()
educated_guess_response = await self._model_client.create(
messages=ledger + [UserMessage(content=educated_guess_promt, source=sender)],
extra_create_args={"response_format": {"type": "json_object"}},
# Send a message to the orchestrator.
self._send_message(TextMessage(content=educated_guess_promt, source=sender), self._orchestrator)
# Request a response.
educated_guess_response = await self._send_message(
RespondNow(response_format={"type": "json_object"}), self._orchestrator
)
# TODO: handle invalid JSON.
# TODO: use typed dictionary.
return json.loads(str(educated_guess_response.content))
async def _rewrite_plan(self, team: str, ledger: List[LLMMessage], sender: str) -> str:
async def _rewrite_plan(self, team: str, sender: str) -> str:
new_plan_prompt = f"""Please come up with a new plan expressed in bullet points. Keep in mind the following team composition, and do not involve any other outside people in the plan -- we cannot contact anyone else.
Team membership:
{team}
""".strip()
ledger.append(
UserMessage(
content=new_plan_prompt,
source=sender,
)
)
new_plan_response = await self._model_client.create(messages=ledger)
# Send a message to the orchestrator.
self._send_message(TextMessage(content=new_plan_prompt, source=sender), self._orchestrator)
# Request a response.
new_plan_response = await self._send_message(RespondNow(), self._orchestrator)
return str(new_plan_response.content)

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Union
from typing import List, Literal, Union
from openai.types.beta import AssistantResponseFormatParam
from agnext.agent_components.image import Image
from agnext.agent_components.types import FunctionCall
@@ -42,7 +44,9 @@ class FunctionExecutionResultMessage(BaseMessage):
Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage]
class RespondNow: ...
@dataclass
class RespondNow:
response_format: Union[Literal["none", "auto"], AssistantResponseFormatParam] = "auto"
class Reset: ...