Allow initiate_chat without passing message (#1244)

* allow initiate_chat without passing message

* test human input

* assert called

* Add missing method a_generate_init_message

* fix tests

* add back skipif

* Update test/agentchat/test_async_get_human_input.py

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
bitnom
2024-01-18 22:46:20 -05:00
committed by GitHub
parent 97296105eb
commit e97b6395af
3 changed files with 76 additions and 12 deletions

View File

@@ -679,6 +679,7 @@ class ConversableAgent(Agent):
silent (bool or None): (Experimental) whether to print the messages for this conversation.
**context: any context information.
"message" needs to be provided if the `generate_init_message` method is not overridden.
Otherwise, input() will be called to get the initial message.
Raises:
RuntimeError: if any async reply functions are registered and not ignored in sync chat.
@@ -707,9 +708,10 @@ class ConversableAgent(Agent):
silent (bool or None): (Experimental) whether to print the messages for this conversation.
**context: any context information.
"message" needs to be provided if the `generate_init_message` method is not overridden.
Otherwise, input() will be called to get the initial message.
"""
self._prepare_chat(recipient, clear_history)
await self.a_send(self.generate_init_message(**context), recipient, silent=silent)
await self.a_send(await self.a_generate_init_message(**context), recipient, silent=silent)
def reset(self):
"""Reset the agent."""
@@ -1583,7 +1585,24 @@ class ConversableAgent(Agent):
Args:
**context: any context information, and "message" parameter needs to be provided.
If message is not given, prompt for it via input()
"""
if "message" not in context:
context["message"] = self.get_human_input(">")
return context["message"]
async def a_generate_init_message(self, **context) -> Union[str, Dict]:
"""Generate the initial message for the agent.
Override this function to customize the initial message based on user's request.
If not overridden, "message" needs to be provided in the context.
Args:
**context: any context information, and "message" parameter needs to be provided.
If message is not given, prompt for it via input()
"""
if "message" not in context:
context["message"] = await self.a_get_human_input(">")
return context["message"]
def register_function(self, function_map: Dict[str, Callable]):

View File

@@ -1,9 +1,11 @@
import asyncio
import os
import sys
from unittest.mock import AsyncMock
import autogen
import pytest
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from conftest import skip_openai # noqa: E402
@@ -25,20 +27,17 @@ async def test_async_get_human_input():
assistant = autogen.AssistantAgent(
name="assistant",
max_consecutive_auto_reply=2,
llm_config={"timeout": 600, "cache_seed": 41, "config_list": config_list, "temperature": 0},
llm_config={"seed": 41, "config_list": config_list, "temperature": 0},
)
user_proxy = autogen.UserProxyAgent(name="user", human_input_mode="ALWAYS", code_execution_config=False)
async def custom_a_get_human_input(prompt):
return "This is a test"
user_proxy.a_get_human_input = custom_a_get_human_input
user_proxy.a_get_human_input = AsyncMock(return_value="This is a test")
user_proxy.register_reply([autogen.Agent, None], autogen.ConversableAgent.a_check_termination_and_human_reply)
await user_proxy.a_initiate_chat(assistant, clear_history=True, message="Hello.")
if __name__ == "__main__":
test_async_get_human_input()
# Test without message
await user_proxy.a_initiate_chat(assistant, clear_history=True)
# Assert that custom a_get_human_input was called at least once
user_proxy.a_get_human_input.assert_called()

View File

@@ -0,0 +1,46 @@
import autogen
import pytest
from unittest.mock import MagicMock
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from conftest import skip_openai # noqa: E402
try:
from openai import OpenAI
except ImportError:
skip = True
else:
skip = False or skip_openai
@pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")
def test_get_human_input():
config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, KEY_LOC)
# create an AssistantAgent instance named "assistant"
assistant = autogen.AssistantAgent(
name="assistant",
max_consecutive_auto_reply=2,
llm_config={"timeout": 600, "cache_seed": 41, "config_list": config_list, "temperature": 0},
)
user_proxy = autogen.UserProxyAgent(name="user", human_input_mode="ALWAYS", code_execution_config=False)
# Use MagicMock to create a mock get_human_input function
user_proxy.get_human_input = MagicMock(return_value="This is a test")
user_proxy.register_reply([autogen.Agent, None], autogen.ConversableAgent.a_check_termination_and_human_reply)
user_proxy.initiate_chat(assistant, clear_history=True, message="Hello.")
# Test without supplying messages parameter
user_proxy.initiate_chat(assistant, clear_history=True)
# Assert that custom_a_get_human_input was called at least once
user_proxy.get_human_input.assert_called()
if __name__ == "__main__":
test_get_human_input()