Orchestrator Chat and OAI Assistant update (#31)
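In short, this change switches the example script from the Orchestrator pattern to the new OrchestratorChat pattern, with both a Planner and an Orchestrator backed by OpenAI Assistants. A condensed sketch of the new wiring, based only on the code added in the diff below (the runtime, developer, and product_manager objects are assumed to be created earlier in the example script):

import openai

from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
from agnext.chat.types import TextMessage


async def run_orchestrator_chat(runtime, developer, product_manager, message: str) -> None:
    # The planner and the orchestrator are both OpenAI Assistants, each with its own thread.
    planner_assistant = openai.beta.assistants.create(
        model="gpt-4-turbo", name="Planner", instructions="You are a planner of complex tasks."
    )
    planner = OpenAIAssistantAgent(
        name="Planner",
        description="A planner that organizes and schedules tasks.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=planner_assistant.id,
        thread_id=openai.beta.threads.create().id,
    )
    orchestrator_assistant = openai.beta.assistants.create(
        model="gpt-4-turbo",
        name="Orchestrator",
        instructions="You are an orchestrator that coordinates the team to complete a complex task.",
    )
    orchestrator = OpenAIAssistantAgent(
        name="Orchestrator",
        description="An orchestrator that coordinates the team.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=orchestrator_assistant.id,
        thread_id=openai.beta.threads.create().id,
    )

    # OrchestratorChat routes subtasks between the orchestrator, the planner, and the specialists.
    chat = OrchestratorChat(
        "Orchestrator Chat",
        "A software development team.",
        runtime,
        orchestrator=orchestrator,
        planner=planner,
        specialists=[developer, product_manager],
    )

    response = runtime.send_message(TextMessage(content=message, source="Customer"), chat)
    while not response.done():
        await runtime.process_next()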
@@ -3,13 +3,12 @@ import asyncio
import logging

import openai
from agnext.agent_components.model_client import OpenAI
from agnext.application_components import (
    SingleThreadedAgentRuntime,
)
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
from agnext.chat.patterns.orchestrator import Orchestrator
from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
from agnext.chat.types import TextMessage
from agnext.core._agent import Agent
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
@@ -38,20 +37,28 @@ class ConcatOutput(GroupChatOutput):


class LoggingHandler(DefaultInterventionHandler):
    send_color = "\033[31m"
    response_color = "\033[34m"
    reset_color = "\033[0m"

    @override
    async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]:
        if sender is None:
            print(f"Sending message to {recipient.name}: {message}")
            print(f"{self.send_color}Sending message to {recipient.name}:{self.reset_color} {message}")
        else:
            print(f"Sending message from {sender.name} to {recipient.name}: {message}")
            print(
                f"{self.send_color}Sending message from {sender.name} to {recipient.name}:{self.reset_color} {message}"
            )
        return message

    @override
    async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
        if recipient is None:
            print(f"Received response from {sender.name}: {message}")
            print(f"{self.response_color}Received response from {sender.name}:{self.reset_color} {message}")
        else:
            print(f"Received response from {sender.name} to {recipient.name}: {message}")
            print(
                f"{self.response_color}Received response from {sender.name} to {recipient.name}:{self.reset_color} {message}"
            )
        return message

@@ -131,18 +138,46 @@ async def orchestrator(message: str) -> None:
        thread_id=product_manager_oai_thread.id,
    )

    chat = Orchestrator(
        "Manager",
        "A software development team manager.",
        runtime,
        [developer, product_manager],
        model_client=OpenAI(model="gpt-3.5-turbo"),
    planner_oai_assistant = openai.beta.assistants.create(
        model="gpt-4-turbo",
        name="Planner",
        instructions="You are a planner of complex tasks.",
    )
    planner_oai_thread = openai.beta.threads.create()
    planner = OpenAIAssistantAgent(
        name="Planner",
        description="A planner that organizes and schedules tasks.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=planner_oai_assistant.id,
        thread_id=planner_oai_thread.id,
    )

    response = runtime.send_message(
        TextMessage(content=message, source="customer"),
        chat,
    orchestrator_oai_assistant = openai.beta.assistants.create(
        model="gpt-4-turbo",
        name="Orchestrator",
        instructions="You are an orchestrator that coordinates the team to complete a complex task.",
    )
    orchestrator_oai_thread = openai.beta.threads.create()
    orchestrator = OpenAIAssistantAgent(
        name="Orchestrator",
        description="An orchestrator that coordinates the team.",
        runtime=runtime,
        client=openai.AsyncClient(),
        assistant_id=orchestrator_oai_assistant.id,
        thread_id=orchestrator_oai_thread.id,
    )

    chat = OrchestratorChat(
        "Orchestrator Chat",
        "A software development team.",
        runtime,
        orchestrator=orchestrator,
        planner=planner,
        specialists=[developer, product_manager],
    )

    response = runtime.send_message(TextMessage(content=message, source="Customer"), chat)

    while not response.done():
        await runtime.process_next()

@@ -1,4 +1,4 @@
from typing import Callable, Dict
from typing import Callable, Dict, List

import openai

@@ -23,8 +23,6 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
        self._client = client
        self._assistant_id = assistant_id
        self._thread_id = thread_id
        # TODO: investigate why this is 1, as setting this to 0 causes the earliest message in the window to be ignored.
        self._current_session_window_length = 1
        self._tools = tools or {}

    @message_handler(TextMessage)
@@ -36,32 +34,40 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
            role="user",
            metadata={"sender": message.source},
        )
        self._current_session_window_length += 1

    @message_handler(Reset)
    async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
        # Reset the current session window.
        self._current_session_window_length = 1
        # Get all messages in this thread.
        all_msgs: List[str] = []
        while True:
            if not all_msgs:
                msgs = await self._client.beta.threads.messages.list(self._thread_id)
            else:
                msgs = await self._client.beta.threads.messages.list(self._thread_id, after=all_msgs[-1])
            for msg in msgs.data:
                all_msgs.append(msg.id)
            if not msgs.has_next_page():
                break
        # Delete all the messages.
        for msg_id in all_msgs:
            status = await self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
            assert status.deleted is True

    @message_handler(RespondNow)
    async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
        # Handle response format.

        # Create a run and wait until it finishes.
        run = await self._client.beta.threads.runs.create_and_poll(
            thread_id=self._thread_id,
            assistant_id=self._assistant_id,
            truncation_strategy={
                "type": "last_messages",
                "last_messages": self._current_session_window_length,
            },
            response_format=message.response_format,
        )

        if run.status != "completed":
            # TODO: handle other statuses.
            raise ValueError(f"Run did not complete successfully: {run}")

        # Increment the current session window length.
        self._current_session_window_length += 1

        # Get the last message from the run.
        response = await self._client.beta.threads.messages.list(self._thread_id, run_id=run.id, order="desc", limit=1)
        last_message_content = response.data[0].content

@@ -1,40 +1,39 @@
import json
from typing import Any, List, Sequence, Tuple
from typing import Any, Sequence, Tuple

from ...agent_components.model_client import ModelClient
from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from ...agent_components.types import AssistantMessage, LLMMessage, UserMessage
from ...core import AgentRuntime, CancellationToken
from ..agents.base import BaseChatAgent
from ..types import RespondNow, TextMessage
from ..types import Reset, RespondNow, TextMessage


class Orchestrator(BaseChatAgent, TypeRoutedAgent):
class OrchestratorChat(BaseChatAgent, TypeRoutedAgent):
    def __init__(
        self,
        name: str,
        description: str,
        runtime: AgentRuntime,
        agents: Sequence[BaseChatAgent],
        model_client: ModelClient,
        orchestrator: BaseChatAgent,
        planner: BaseChatAgent,
        specialists: Sequence[BaseChatAgent],
        max_turns: int = 30,
        max_stalled_turns_before_retry: int = 2,
        max_retry_attempts: int = 1,
    ) -> None:
        super().__init__(name, description, runtime)
        self._agents = agents
        self._model_client = model_client
        self._orchestrator = orchestrator
        self._planner = planner
        self._specialists = specialists
        self._max_turns = max_turns
        self._max_stalled_turns_before_retry = max_stalled_turns_before_retry
        self._max_retry_attempts_before_educated_guess = max_retry_attempts
        self._history: List[TextMessage] = []

    @message_handler(TextMessage)
    async def on_text_message(
        self,
        message: TextMessage,
        cancellation_token: CancellationToken,
    ) -> TextMessage | None:
    ) -> TextMessage:
        # A task is received.
        task = message.content

@@ -44,8 +43,11 @@ class Orchestrator(BaseChatAgent, TypeRoutedAgent):
        # Main loop.
        total_turns = 0
        retry_attempts = 0
        ledgers: List[List[LLMMessage]] = []
        while total_turns < self._max_turns:
            # Reset all agents.
            for agent in [*self._specialists, self._orchestrator]:
                self._send_message(Reset(), agent)

            # Create the task specs.
            task_specs = f"""
We are working to address the following user request:
@@ -64,30 +66,15 @@ Some additional points to consider:
{plan}
""".strip()

            # Send the task specs to the team and signal a reset.
            for agent in self._agents:
                self._send_message(
                    TextMessage(
                        content=task_specs,
                        source=self.name,
                    ),
                    agent,
                )

            # Create the ledger.
            ledger: List[LLMMessage] = [
                AssistantMessage(
                    content=task_specs,
                    source=self.name,
                )
            ]
            ledgers.append(ledger)
            # Send the task specs to the orchestrator and specialists.
            for agent in [*self._specialists, self._orchestrator]:
                self._send_message(TextMessage(content=task_specs, source=self.name), agent)

            # Inner loop.
            stalled_turns = 0
            while total_turns < self._max_turns:
                # Reflect on the task.
                data = await self._reflect_on_task(task, team, names, ledger, message.source)
                data = await self._reflect_on_task(task, team, names, message.source)

                # Check if the request is satisfied.
                if data["is_request_satisfied"]["answer"]:
@@ -107,7 +94,7 @@ Some additional points to consider:
                    # In a retry, we need to rewrite the facts and the plan.

                    # Rewrite the facts.
                    facts = await self._rewrite_facts(facts, ledger, message.source)
                    facts = await self._rewrite_facts(facts, message.source)

                    # Increment the retry attempts.
                    retry_attempts += 1
@@ -115,7 +102,7 @@ Some additional points to consider:
                    # Check if we should just guess.
                    if retry_attempts > self._max_retry_attempts_before_educated_guess:
                        # Make an educated guess.
                        educated_guess = await self._educated_guess(facts, ledger, message.source)
                        educated_guess = await self._educated_guess(facts, message.source)
                        if educated_guess["has_educated_guesses"]["answer"]:
                            return TextMessage(
                                content=f"The task is addressed with an educated guess. {educated_guess['has_educated_guesses']['reason']}",
@@ -123,7 +110,7 @@ Some additional points to consider:
                            )

                    # Come up with a new plan.
                    plan = await self._rewrite_plan(team, ledger, message.source)
                    plan = await self._rewrite_plan(team, message.source)

                    # Exit the inner loop.
                    break
@@ -134,7 +121,7 @@ Some additional points to consider:
                    subtask = ""

                # Update agents.
                for agent in [agent for agent in self._agents]:
                for agent in [*self._specialists, self._orchestrator]:
                    _ = await self._send_message(
                        TextMessage(content=subtask, source=self.name),
                        agent,
@@ -142,41 +129,20 @@ Some additional points to consider:

                # Find the speaker.
                try:
                    speaker = next(agent for agent in self._agents if agent.name == data["next_speaker"]["answer"])
                    speaker = next(agent for agent in self._specialists if agent.name == data["next_speaker"]["answer"])
                except StopIteration as e:
                    raise ValueError(f"Invalid next speaker: {data['next_speaker']['answer']}") from e

                # As speaker to speak.
                # Ask speaker to speak.
                speaker_response = await self._send_message(RespondNow(), speaker)

                assert speaker_response is not None

                # Update the ledger.
                ledger.append(
                    AssistantMessage(
                        content=subtask,
                        source=self.name,
                    )
                )

                # Update all other agents with the speaker's response.
                for agent in [agent for agent in self._agents if agent != speaker]:
                    _ = await self._send_message(
                        TextMessage(
                            content=speaker_response.content,
                            source=speaker_response.source,
                        ),
                        agent,
                for agent in [agent for agent in self._specialists if agent != speaker] + [self._orchestrator]:
                    self._send_message(
                        TextMessage(content=speaker_response.content, source=speaker_response.source), agent
                    )

                # Update the ledger.
                ledger.append(
                    UserMessage(
                        content=speaker_response.content,
                        source=speaker_response.source,
                    )
                )

                # Increment the total turns.
                total_turns += 1

@@ -186,9 +152,12 @@ Some additional points to consider:
        )

    async def _prepare_task(self, task: str, sender: str) -> Tuple[str, str, str, str]:
        # Reset planner.
        self._send_message(Reset(), self._planner)

        # A reusable description of the team.
        team = "\n".join([agent.name + ": " + agent.description for agent in self._agents])
        names = ", ".join([agent.name for agent in self._agents])
        team = "\n".join([agent.name + ": " + agent.description for agent in self._specialists])
        names = ", ".join([agent.name for agent in self._specialists])

        # A place to store relevant facts.
        facts = ""
@@ -218,19 +187,10 @@ When answering this survey, keep in mind that "facts" will typically be specific
|
||||
4. EDUCATED GUESSES
|
||||
""".strip()
|
||||
|
||||
starter_messages: List[LLMMessage] = [
|
||||
UserMessage(
|
||||
content=closed_book_prompt,
|
||||
source=sender,
|
||||
)
|
||||
]
|
||||
facts_response = await self._model_client.create(messages=starter_messages)
|
||||
starter_messages.append(
|
||||
AssistantMessage(
|
||||
content=facts_response.content,
|
||||
source=self.name,
|
||||
)
|
||||
)
|
||||
# Ask the planner to obtain prior knowledge about facts.
|
||||
self._send_message(TextMessage(content=closed_book_prompt, source=sender), self._planner)
|
||||
facts_response = await self._send_message(RespondNow(), self._planner)
|
||||
|
||||
facts = str(facts_response.content)
|
||||
|
||||
# Make an initial plan
|
||||
@@ -239,19 +199,10 @@ When answering this survey, keep in mind that "facts" will typically be specific
|
||||
{team}
|
||||
|
||||
Based on the team composition, and known and unknown facts, please devise a short bullet-point plan for addressing the original request. Remember, there is no requirement to involve all team members -- a team member's particular expertise may not be needed for this task.""".strip()
|
||||
starter_messages.append(
|
||||
UserMessage(
|
||||
content=plan_prompt,
|
||||
source=sender,
|
||||
)
|
||||
)
|
||||
plan_response = await self._model_client.create(messages=starter_messages)
|
||||
starter_messages.append(
|
||||
AssistantMessage(
|
||||
content=plan_response.content,
|
||||
source=self.name,
|
||||
)
|
||||
)
|
||||
|
||||
# Send second messag eto the planner.
|
||||
self._send_message(TextMessage(content=plan_prompt, source=sender), self._planner)
|
||||
plan_response = await self._send_message(RespondNow(), self._planner)
|
||||
plan = str(plan_response.content)
|
||||
|
||||
return team, names, facts, plan
|
||||
@@ -261,7 +212,6 @@ Based on the team composition, and known and unknown facts, please devise a shor
        task: str,
        team: str,
        names: str,
        ledger: List[LLMMessage],
        sender: str,
    ) -> Any:
        step_prompt = f"""
@@ -301,37 +251,28 @@ Please output an answer in pure JSON format according to the following schema. T
}}
}}
""".strip()
        step_response = await self._model_client.create(
            messages=ledger + [UserMessage(content=step_prompt, source=sender)],
            extra_create_args={"response_format": {"type": "json_object"}},
        # Send a message to the orchestrator.
        self._send_message(TextMessage(content=step_prompt, source=sender), self._orchestrator)
        # Request a response.
        step_response = await self._send_message(
            RespondNow(response_format={"type": "json_object"}), self._orchestrator
        )
        step_response_json = str(step_response.content)
        # TODO: handle invalid JSON.
        # TODO: use typed dictionary.
        return json.loads(step_response_json)
        return json.loads(step_response.content)

    async def _rewrite_facts(self, facts: str, ledger: List[LLMMessage], sender: str) -> str:
    async def _rewrite_facts(self, facts: str, sender: str) -> str:
        new_facts_prompt = f"""It's clear we aren't making as much progress as we would like, but we may have learned something new. Please rewrite the following fact sheet, updating it to include anything new we have learned. This is also a good time to update educated guesses (please add or update at least one educated guess or hunch, and explain your reasoning).

{facts}
""".strip()
        ledger.append(
            UserMessage(
                content=new_facts_prompt,
                source=sender,
            )
        )
        new_facts_response = await self._model_client.create(messages=ledger)
        facts = str(new_facts_response.content)
        ledger.append(
            AssistantMessage(
                content=facts,
                source=self.name,
            )
        )
        return facts
        # Send a message to the orchestrator.
        self._send_message(TextMessage(content=new_facts_prompt, source=sender), self._orchestrator)
        # Request a response.
        new_facts_response = await self._send_message(RespondNow(), self._orchestrator)
        return str(new_facts_response.content)

    async def _educated_guess(self, facts: str, ledger: List[LLMMessage], sender: str) -> Any:
    async def _educated_guess(self, facts: str, sender: str) -> Any:
        # Make an educated guess.
        educated_guess_promt = f"""Given the following information

@@ -348,25 +289,24 @@ Please output an answer in pure JSON format according to the following schema. T
}}
}}
""".strip()
        educated_guess_response = await self._model_client.create(
            messages=ledger + [UserMessage(content=educated_guess_promt, source=sender)],
            extra_create_args={"response_format": {"type": "json_object"}},
        # Send a message to the orchestrator.
        self._send_message(TextMessage(content=educated_guess_promt, source=sender), self._orchestrator)
        # Request a response.
        educated_guess_response = await self._send_message(
            RespondNow(response_format={"type": "json_object"}), self._orchestrator
        )
        # TODO: handle invalid JSON.
        # TODO: use typed dictionary.
        return json.loads(str(educated_guess_response.content))

    async def _rewrite_plan(self, team: str, ledger: List[LLMMessage], sender: str) -> str:
    async def _rewrite_plan(self, team: str, sender: str) -> str:
        new_plan_prompt = f"""Please come up with a new plan expressed in bullet points. Keep in mind the following team composition, and do not involve any other outside people in the plan -- we cannot contact anyone else.

Team membership:
{team}
""".strip()
        ledger.append(
            UserMessage(
                content=new_plan_prompt,
                source=sender,
            )
        )
        new_plan_response = await self._model_client.create(messages=ledger)
        # Send a message to the orchestrator.
        self._send_message(TextMessage(content=new_plan_prompt, source=sender), self._orchestrator)
        # Request a response.
        new_plan_response = await self._send_message(RespondNow(), self._orchestrator)
        return str(new_plan_response.content)
@@ -1,7 +1,9 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Union
from typing import List, Literal, Union

from openai.types.beta import AssistantResponseFormatParam

from agnext.agent_components.image import Image
from agnext.agent_components.types import FunctionCall
@@ -42,7 +44,9 @@ class FunctionExecutionResultMessage(BaseMessage):
Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage]


class RespondNow: ...
@dataclass
class RespondNow:
    response_format: Union[Literal["none", "auto"], AssistantResponseFormatParam] = "auto"


class Reset: ...

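As a usage note, RespondNow is now a dataclass carrying an optional response format, which is what lets OrchestratorChat request JSON answers from the orchestrator assistant. A minimal, self-contained sketch of the new shape (it mirrors the definition in the diff above; the AssistantResponseFormatParam import is the one added there):

from dataclasses import dataclass
from typing import Literal, Union

from openai.types.beta import AssistantResponseFormatParam


@dataclass
class RespondNow:
    # "auto" leaves formatting to the assistant; a dict such as {"type": "json_object"}
    # asks the OpenAI Assistants API for machine-readable JSON output.
    response_format: Union[Literal["none", "auto"], AssistantResponseFormatParam] = "auto"


default_request = RespondNow()  # no particular format requested
json_request = RespondNow(response_format={"type": "json_object"})  # as used by OrchestratorChat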