Make groupchat & generation async, actually (#543)

* make groupchat & generation async actually

* factored out func call pre-select; updated indecies

* fixed code format issue

* mark prepare agents subset as internal

* func renaming

* func inputs

* return agents

* Update test/agentchat/test_async.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* Update notebook/agentchat_stream.ipynb

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>
This commit is contained in:
kittyandrew
2023-12-09 22:50:36 +02:00
committed by GitHub
parent 379d7bd687
commit 6e2387192f
4 changed files with 67 additions and 15 deletions

View File

@@ -1,5 +1,6 @@
import asyncio
import copy
import functools
import json
import logging
from collections import defaultdict
@@ -133,9 +134,10 @@ class ConversableAgent(Agent):
self._reply_func_list = []
self.reply_at_receive = defaultdict(bool)
self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
self.register_reply([Agent, None], ConversableAgent.generate_async_function_call_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_function_call_reply)
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
self.register_reply([Agent, None], ConversableAgent.a_check_termination_and_human_reply)
@@ -631,6 +633,17 @@ class ConversableAgent(Agent):
)
return True, client.extract_text_or_function_call(response)[0]
async def a_generate_oai_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
"""Generate a reply using autogen.oai asynchronously."""
return await asyncio.get_event_loop().run_in_executor(
None, functools.partial(self.generate_oai_reply, messages=messages, sender=sender, config=config)
)
def generate_code_execution_reply(
self,
messages: Optional[List[Dict]] = None,
@@ -697,7 +710,7 @@ class ConversableAgent(Agent):
return True, func_return
return False, None
async def generate_async_function_call_reply(
async def a_generate_function_call_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,

View File

@@ -3,7 +3,7 @@ import random
import re
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Tuple
from ..code_utils import content_str
from .agent import Agent
@@ -118,8 +118,7 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
print(f"Invalid input. Please enter a number between 1 and {_n_agents}.")
return None
def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
"""Select the next speaker."""
def _prepare_and_select_agents(self, last_speaker: Agent) -> Tuple[Optional[Agent], List[Agent]]:
if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
raise ValueError(
f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
@@ -148,30 +147,35 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
]
if len(agents) == 1:
# only one agent can execute the function
return agents[0]
return agents[0], agents
elif not agents:
# find all the agents with function_map
agents = [agent for agent in self.agents if agent.function_map]
if len(agents) == 1:
return agents[0]
return agents[0], agents
elif not agents:
raise ValueError(
f"No agent can execute the function {self.messages[-1]['name']}. "
"Please check the function_map of the agents."
)
# remove the last speaker from the list to avoid selecting the same speaker if allow_repeat_speaker is False
agents = agents if self.allow_repeat_speaker else [agent for agent in agents if agent != last_speaker]
if self.speaker_selection_method.lower() == "manual":
selected_agent = self.manual_select_speaker(agents)
if selected_agent:
return selected_agent
elif self.speaker_selection_method.lower() == "round_robin":
return self.next_agent(last_speaker, agents)
selected_agent = self.next_agent(last_speaker, agents)
elif self.speaker_selection_method.lower() == "random":
return random.choice(agents)
selected_agent = random.choice(agents)
else:
selected_agent = None
return selected_agent, agents
def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
"""Select the next speaker."""
selected_agent, agents = self._prepare_and_select_agents(last_speaker)
if selected_agent:
return selected_agent
# auto speaker selection
selector.update_system_message(self.select_speaker_msg(agents))
context = self.messages + [{"role": "system", "content": self.select_speaker_prompt(agents)}]
@@ -196,6 +200,41 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
except ValueError:
return self.next_agent(last_speaker, agents)
async def a_select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
"""Select the next speaker."""
selected_agent, agents = self._prepare_and_select_agents(last_speaker)
if selected_agent:
return selected_agent
# auto speaker selection
selector.update_system_message(self.select_speaker_msg(agents))
final, name = await selector.a_generate_oai_reply(
self.messages
+ [
{
"role": "system",
"content": f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.",
}
]
)
if not final:
# the LLM client is None, thus no reply is generated. Use round robin instead.
return self.next_agent(last_speaker, agents)
# If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified
mentions = self._mentioned_agents(name, agents)
if len(mentions) == 1:
name = next(iter(mentions))
else:
logger.warning(
f"GroupChat select_speaker failed to resolve the next speaker's name. This is because the speaker selection OAI call returned:\n{name}"
)
# Return the result
try:
return self.agent_by_name(name)
except ValueError:
return self.next_agent(last_speaker, agents)
def _participant_roles(self, agents: List[Agent] = None) -> str:
# Default to all agents registered
if agents is None:
@@ -342,7 +381,7 @@ class GroupChatManager(ConversableAgent):
break
try:
# select the next speaker
speaker = groupchat.select_speaker(speaker, self)
speaker = await groupchat.a_select_speaker(speaker, self)
# let the speaker speak
reply = await speaker.a_generate_reply(sender=self)
except KeyboardInterrupt:

View File

@@ -238,7 +238,7 @@
" )\n",
" return False, None\n",
"\n",
"user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, 1, config={\"news_stream\": data})"
"user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, position=2, config={\"news_stream\": data})"
]
},
{

View File

@@ -146,7 +146,7 @@ async def test_stream():
)
return False, None
user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, 1, config={"news_stream": data})
user_proxy.register_reply(autogen.AssistantAgent, add_data_reply, position=2, config={"news_stream": data})
await user_proxy.a_initiate_chat(
assistant,