Async version of multiple sequential chat (#1724)

* async_initiate_chats init commit

* Fix a_get_human_input bug

* Add agentchat_multi_task_async_chats.ipynb with concurrent examples.

* Address the comments, Update unit test

* Add agentchat_multi_task_async_chats.ipynb to Examples.md

* Fix type for Python 3.8

---------

Co-authored-by: Qingyun Wu <qingyun0327@gmail.com>
This commit is contained in:
Aristo
2024-02-21 11:33:33 -08:00
committed by GitHub
parent a34e4cc515
commit a4ab4cc9ae
5 changed files with 2152 additions and 35 deletions

View File

@@ -1,5 +1,7 @@
import asyncio
import logging
from typing import Dict, List, Any
from collections import defaultdict
from typing import Dict, List, Any, Set, Tuple
from dataclasses import dataclass
from .utils import consolidate_chat_info
import warnings
@@ -13,12 +15,15 @@ except ImportError:
logger = logging.getLogger(__name__)
Prerequisite = Tuple[int, int]
@dataclass
class ChatResult:
"""(Experimental) The result of a chat. Almost certain to be changed."""
chat_id: int = None
"""chat id"""
chat_history: List[Dict[str, any]] = None
"""The chat history."""
summary: str = None
@@ -29,6 +34,103 @@ class ChatResult:
"""A list of human input solicited during the chat."""
def _validate_recipients(chat_queue: List[Dict[str, Any]]) -> None:
    """
    Validate that every chat has a recipient and warn about repetitive recipients.

    Args:
        chat_queue (List[Dict[str, Any]]): a list of chat configurations; each
            entry must contain a "recipient" key.

    Raises:
        AssertionError: if any entry in chat_queue lacks a "recipient".
    """
    receipts_set = set()
    for chat_info in chat_queue:
        assert "recipient" in chat_info, "recipient must be provided."
        receipts_set.add(chat_info["recipient"])
    # Fewer unique recipients than chats means at least one recipient repeats.
    if len(receipts_set) < len(chat_queue):
        warnings.warn(
            "Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.",
            UserWarning,
        )
def __create_async_prerequisites(chat_queue: List[Dict[str, Any]]) -> List[Prerequisite]:
    """
    Create a list of Prerequisite tuples of the form (chat_id, prerequisite_chat_id).

    Args:
        chat_queue (List[Dict[str, Any]]): a list of chat configurations; each entry
            must contain a unique "chat_id" and may contain a "prerequisites" list
            of int chat ids that must finish before this chat starts.

    Returns:
        List[Prerequisite]: one (chat_id, prerequisite_chat_id) pair per declared
            prerequisite.

    Raises:
        ValueError: if an entry lacks "chat_id", or a prerequisite id is not an int.
    """
    prerequisites = []
    for chat_info in chat_queue:
        if "chat_id" not in chat_info:
            raise ValueError("Each chat must have a unique id for async multi-chat execution.")
        chat_id = chat_info["chat_id"]
        pre_chats = chat_info.get("prerequisites", [])
        for pre_chat_id in pre_chats:
            if not isinstance(pre_chat_id, int):
                raise ValueError("Prerequisite chat id is not int.")
            prerequisites.append((chat_id, pre_chat_id))
    return prerequisites
def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite]) -> List[int]:
    """Find a chat order for async execution based on the prerequisite chats.

    Performs a topological sort (Kahn's algorithm) over the prerequisite graph.

    args:
        chat_ids: the set of all chat ids.
        prerequisites: List of Prerequisite (chat_id, prerequisite_chat_id).

    returns:
        list: a list of chat_id in a valid execution order; an empty list if the
            prerequisites contain a cycle (no valid order exists).
    """
    # edges maps a prerequisite chat id to the chats that depend on it;
    # indegree counts how many distinct prerequisites each chat still has.
    edges = defaultdict(set)
    indegree = defaultdict(int)
    for pair in prerequisites:
        chat, pre = pair[0], pair[1]
        if chat not in edges[pre]:
            indegree[chat] += 1
            edges[pre].add(chat)
    # Seed the frontier with chats that have no prerequisites at all.
    bfs = [i for i in chat_ids if i not in indegree]
    chat_order = []
    steps = len(indegree)
    for _ in range(steps + 1):
        if not bfs:
            break
        chat_order.extend(bfs)
        nxt = []
        for node in bfs:
            if node in edges:
                for course in edges[node]:
                    indegree[course] -= 1
                    if indegree[course] == 0:
                        nxt.append(course)
                        indegree.pop(course)
                edges.pop(node)
        bfs = nxt

    # Any chat still holding an indegree is part of a cycle: no valid order.
    if indegree:
        return []
    return chat_order
def __post_carryover_processing(chat_info: Dict[str, Any]):
    """Print the initial message and consolidated carryover for a chat about to start.

    Args:
        chat_info (Dict[str, Any]): the chat configuration; "carryover" may be a
            string or a list of strings, and "message" may be absent (in which
            case a warning is emitted and the message prints as empty).
    """
    if "message" not in chat_info:
        warnings.warn(
            "message is not provided in a chat_queue entry. input() will be called to get the initial message.",
            UserWarning,
        )
    print_carryover = (
        ("\n").join([t for t in chat_info["carryover"]])
        if isinstance(chat_info["carryover"], list)
        else chat_info["carryover"]
    )
    print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
    print(
        colored(
            "Start a new chat with the following message: \n"
            # `get` returns None when "message" is absent; fall back to "" so the
            # concatenation does not raise TypeError after the warning above.
            + (chat_info.get("message") or "")
            + "\n\nWith the following carryover: \n"
            + print_carryover,
            "blue",
        ),
        flush=True,
    )
    print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
"""Initiate a list of chats.
@@ -71,15 +173,7 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
"""
consolidate_chat_info(chat_queue)
receipts_set = set()
for chat_info in chat_queue:
assert "recipient" in chat_info, "recipient must be provided."
receipts_set.add(chat_info["recipient"])
if len(receipts_set) < len(chat_queue):
warnings.warn(
"Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.",
UserWarning,
)
_validate_recipients(chat_queue)
current_chat_queue = chat_queue.copy()
finished_chats = []
while current_chat_queue:
@@ -88,30 +182,48 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [r.summary for r in finished_chats]
if "message" not in chat_info:
warnings.warn(
"message is not provided in a chat_queue entry. input() will be called to get the initial message.",
UserWarning,
)
chat_info["recipient"]
print_carryover = (
("\n").join([t for t in chat_info["carryover"]])
if isinstance(chat_info["carryover"], list)
else chat_info["carryover"]
)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
print(
colored(
"Start a new chat with the following message: \n"
+ chat_info.get("message")
+ "\n\nWith the following carryover: \n"
+ print_carryover,
"blue",
),
flush=True,
)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
__post_carryover_processing(chat_info)
sender = chat_info["sender"]
chat_res = sender.initiate_chat(**chat_info)
finished_chats.append(chat_res)
return finished_chats
async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
    """(async) Initiate a list of chats.

    args:
        Please refer to `initiate_chats`.

    returns:
        (Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue.
    """
    consolidate_chat_info(chat_queue)
    _validate_recipients(chat_queue)
    # Index the chat configurations by id and compute a prerequisite-respecting order.
    chat_book = {info["chat_id"]: info for info in chat_queue}
    prerequisites = __create_async_prerequisites(chat_queue)
    ordered_chat_ids = __find_async_chat_order(chat_book.keys(), prerequisites)
    finished_chats = {}
    for chat_id in ordered_chat_ids:
        chat_info = chat_book[chat_id]
        condition = asyncio.Condition()
        prerequisite_chat_ids = chat_info.get("prerequisites", [])
        async with condition:
            # Block until every prerequisite chat has recorded its result.
            await condition.wait_for(lambda: all(pre_id in finished_chats for pre_id in prerequisite_chat_ids))
            # Do the actual work here.
            carryover = chat_info.get("carryover", [])
            if isinstance(carryover, str):
                carryover = [carryover]
            # Append each prerequisite chat's summary to this chat's carryover.
            chat_info["carryover"] = carryover + [finished_chats[pre_id].summary for pre_id in prerequisite_chat_ids]
            __post_carryover_processing(chat_info)
            sender = chat_info["sender"]
            chat_res = await sender.a_initiate_chat(**chat_info)
            chat_res.chat_id = chat_id
            finished_chats[chat_id] = chat_res
    return finished_chats

View File

@@ -26,7 +26,7 @@ from ..code_utils import (
infer_lang,
)
from .utils import gather_usage_summary, consolidate_chat_info
from .chat import ChatResult, initiate_chats
from .chat import ChatResult, initiate_chats, a_initiate_chats
from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
@@ -985,6 +985,13 @@ class ConversableAgent(LLMAgent):
self._finished_chats = initiate_chats(_chat_queue)
return self._finished_chats
async def a_initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
    """(async) Initiate a list of chats with this agent as the sender of each chat.

    Args:
        chat_queue (List[Dict[str, Any]]): a list of chat configurations; each
            entry gets this agent installed under the "sender" key.

    Returns:
        Dict[int, ChatResult]: a dict mapping chat_id to the finished ChatResult.
    """
    queue = chat_queue.copy()
    for entry in queue:
        entry["sender"] = self
    # Delegates to the module-level a_initiate_chats helper.
    self._finished_chats = await a_initiate_chats(queue)
    return self._finished_chats
def get_chat_results(self, chat_index: Optional[int] = None) -> Union[List[ChatResult], ChatResult]:
"""A summary from the finished chats of particular agents."""
if chat_index is not None:
@@ -1766,7 +1773,7 @@ class ConversableAgent(LLMAgent):
str: human input.
"""
reply = input(prompt)
self._human_inputs.append(reply)
self._human_input.append(reply)
return reply
def run_code(self, code, **kwargs):

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,116 @@
from autogen import AssistantAgent, UserProxyAgent
from autogen import GroupChat, GroupChatManager
import asyncio
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
import pytest
from conftest import skip_openai
import autogen
from typing import Literal
from typing_extensions import Annotated
from autogen import initiate_chats
@pytest.mark.skipif(skip_openai, reason="requested to skip openai tests")
@pytest.mark.asyncio
async def test_async_chats():
    # End-to-end exercise of UserProxyAgent.a_initiate_chats: four chats wired
    # together via "prerequisites" so each chat only starts after its dependencies
    # finish, with summaries carried over between them. Requires live OAI config.
    config_list = autogen.config_list_from_json(
        OAI_CONFIG_LIST,
        file_location=KEY_LOC,
    )

    financial_tasks = [
        """What are the full names of NVDA and TESLA.""",
        """Get their stock price.""",
        """Analyze pros and cons. Keep it short.""",
    ]

    writing_tasks = ["""Develop a short but engaging blog post using any information provided."""]

    financial_assistant_1 = AssistantAgent(
        name="Financial_assistant_1",
        llm_config={"config_list": config_list},
    )
    financial_assistant_2 = AssistantAgent(
        name="Financial_assistant_2",
        llm_config={"config_list": config_list},
    )
    writer = AssistantAgent(
        name="Writer",
        llm_config={"config_list": config_list},
        is_termination_msg=lambda x: x.get("content", "").find("TERMINATE") >= 0,
        system_message="""
        You are a professional writer, known for
        your insightful and engaging articles.
        You transform complex concepts into compelling narratives.
        Reply "TERMINATE" in the end when everything is done.
        """,
    )

    user = UserProxyAgent(
        name="User",
        human_input_mode="NEVER",
        is_termination_msg=lambda x: x.get("content", "").find("TERMINATE") >= 0,
        code_execution_config={
            "last_n_messages": 1,
            "work_dir": "tasks",
            "use_docker": False,
        },  # Please set use_docker=True if docker is available to run the generated code. Using docker is safer than running the generated code directly.
    )

    def my_summary_method(recipient, sender):
        # Custom summary: the first message of the conversation with `sender`.
        return recipient.chat_messages[sender][0].get("content", "")

    # Chat 1 has no prerequisites; 2 depends on 1; 3 on 1 and 2; 4 on all three.
    chat_res = await user.a_initiate_chats(
        [
            {
                "chat_id": 1,
                "recipient": financial_assistant_1,
                "message": financial_tasks[0],
                "silent": False,
                "summary_method": my_summary_method,
            },
            {
                "chat_id": 2,
                "prerequisites": [1],
                "recipient": financial_assistant_2,
                "message": financial_tasks[1],
                "silent": True,
                "summary_method": "reflection_with_llm",
            },
            {
                "chat_id": 3,
                "prerequisites": [1, 2],
                "recipient": financial_assistant_1,
                "message": financial_tasks[2],
                "summary_method": "last_msg",
                "clear_history": False,
            },
            {
                "chat_id": 4,
                "prerequisites": [1, 2, 3],
                "recipient": writer,
                "message": writing_tasks[0],
                "carryover": "I want to include a figure or a table of data in the blogpost.",
                "summary_method": "last_msg",
            },
        ]
    )

    last_chat_id = 4
    chat_w_writer = chat_res[last_chat_id]
    print(chat_w_writer.chat_history, chat_w_writer.summary, chat_w_writer.cost)

    # Results are also retrievable from the sender afterwards, by id or all at once.
    all_res = user.get_chat_results()
    writer_res = user.get_chat_results(last_chat_id)

    # print(blogpost.summary, insights_and_blogpost)
    print(writer_res.summary, writer_res.cost)
    print(all_res[1].human_input)
    print(all_res[1].summary)
    print(all_res[1].chat_history)
    print(all_res[2].summary)
if __name__ == "__main__":
    # test_async_chats is a coroutine function: calling it directly only creates
    # a coroutine object and never runs the test. Drive it with asyncio.run.
    asyncio.run(test_async_chats())

View File

@@ -24,7 +24,9 @@ Links to notebook examples:
- Running a group chat as an inner-monologue via the SocietyOfMindAgent - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_society_of_mind.ipynb)
1. **Sequential Multi-Agent Chats**
- Solving Multiple Tasks in a Sequence of Chats Initiated by a Single Agent - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_multi_task_chats.ipynb)
- Async-solving Multiple Tasks in a Sequence of Chats Initiated by a Single Agent - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_multi_task_async_chats.ipynb)
- Solving Multiple Tasks in a Sequence of Chats Initiated by Different Agents - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchats.ipynb)
1. **Applications**
- Automated Chess Game Playing & Chitchatting by GPT-4 Agents - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_chess.ipynb)