updates docstr + fix spelling

This commit is contained in:
Wael Karkoub
2024-04-18 17:31:01 +01:00
parent 2df0f39b00
commit 549b5ac96d
2 changed files with 81 additions and 3 deletions

View File

@@ -77,6 +77,7 @@ class ConversableAgent(LLMAgent):
llm_config: Optional[Union[Dict, Literal[False]]] = None,
default_auto_reply: Union[str, Dict] = "",
description: Optional[str] = None,
messages: Optional[Dict[Agent, List[Dict]]] = None,
):
"""
Args:
@@ -121,7 +122,10 @@ class ConversableAgent(LLMAgent):
When set to None, will use self.DEFAULT_CONFIG, which defaults to False.
default_auto_reply (str or dict): default auto reply when no code execution or llm-based reply is generated.
description (str): a short description of the agent. This description is used by other agents
(e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_message)
(e.g. the GroupChatManager) to decide when to call upon this agent. (Default: system_mess
messages (dict): the previous chat messages that this agent had in the past with other agens. Can be used to
give the agent a memory by providing the chat history. This will allow the agent to resume previous had
conversations. Defaults to an empty chat history.
"""
# we change code_execution_config below and we have to make sure we don't change the input
# in case of UserProxyAgent, without this we could even change the default value {}
@@ -131,7 +135,11 @@ class ConversableAgent(LLMAgent):
self._name = name
# a dictionary of conversations, default value is list
self._oai_messages = defaultdict(list)
if messages is None:
self._oai_messages = defaultdict(list)
else:
self._oai_messages = messages
self._oai_system_message = [{"content": system_message, "role": "system"}]
self._description = description if description is not None else system_message
self._is_termination_msg = (
@@ -1211,7 +1219,6 @@ class ConversableAgent(LLMAgent):
return self._finished_chats
async def a_initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
_chat_queue = self._check_chat_queue_for_sender(chat_queue)
self._finished_chats = await a_initiate_chats(_chat_queue)
return self._finished_chats

View File

@@ -1311,6 +1311,77 @@ def test_messages_with_carryover():
assert len(generated_message["content"]) == 2
def test_chat_history():
alice = autogen.ConversableAgent(
"alice",
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is alice speaking.",
)
charlie = autogen.ConversableAgent(
"charlie",
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is charlie speaking.",
)
max_turns = 2
def bob_initiate_chat(agent: ConversableAgent, text: Literal["past", "future"]):
_ = agent.initiate_chat(
alice,
message=f"This is bob from the {text} speaking.",
max_turns=max_turns,
clear_history=False,
silent=True,
)
_ = agent.initiate_chat(
charlie,
message=f"This is bob from the {text} speaking.",
max_turns=max_turns,
clear_history=False,
silent=True,
)
bob = autogen.ConversableAgent(
"bob",
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is bob from the past speaking.",
)
bob_initiate_chat(bob, "past")
context = bob.chat_messages
del bob
# Test agent with chat history
bob = autogen.ConversableAgent(
"bob",
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is bob from the future speaking.",
messages=context,
)
assert bool(bob.chat_messages)
assert bob.chat_messages == context
# two times the max turns due to bob replies
assert len(bob.chat_messages[alice]) == 2 * max_turns
assert len(bob.chat_messages[charlie]) == 2 * max_turns
bob_initiate_chat(bob, "future")
assert len(bob.chat_messages[alice]) == 4 * max_turns
assert len(bob.chat_messages[charlie]) == 4 * max_turns
assert bob.chat_messages[alice][0]["content"] == "This is bob from the past speaking."
assert bob.chat_messages[charlie][0]["content"] == "This is bob from the past speaking."
assert bob.chat_messages[alice][-2]["content"] == "This is bob from the future speaking."
assert bob.chat_messages[charlie][-2]["content"] == "This is bob from the future speaking."
if __name__ == "__main__":
# test_trigger()
# test_context()