mirror of
https://github.com/microsoft/autogen.git
synced 2026-01-26 23:10:23 -05:00
Allow initiate_chat without passing message (#1244)
* allow initiate_chat without passing message * test human input * assert called * Add missing method a_generate_init_message * fix tests * add back skipif * Update test/agentchat/test_async_get_human_input.py --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
@@ -679,6 +679,7 @@ class ConversableAgent(Agent):
|
||||
silent (bool or None): (Experimental) whether to print the messages for this conversation.
|
||||
**context: any context information.
|
||||
"message" needs to be provided if the `generate_init_message` method is not overridden.
|
||||
Otherwise, input() will be called to get the initial message.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if any async reply functions are registered and not ignored in sync chat.
|
||||
@@ -707,9 +708,10 @@ class ConversableAgent(Agent):
|
||||
silent (bool or None): (Experimental) whether to print the messages for this conversation.
|
||||
**context: any context information.
|
||||
"message" needs to be provided if the `generate_init_message` method is not overridden.
|
||||
Otherwise, input() will be called to get the initial message.
|
||||
"""
|
||||
self._prepare_chat(recipient, clear_history)
|
||||
await self.a_send(self.generate_init_message(**context), recipient, silent=silent)
|
||||
await self.a_send(await self.a_generate_init_message(**context), recipient, silent=silent)
|
||||
|
||||
def reset(self):
|
||||
"""Reset the agent."""
|
||||
@@ -1583,7 +1585,24 @@ class ConversableAgent(Agent):
|
||||
|
||||
Args:
|
||||
**context: any context information, and "message" parameter needs to be provided.
|
||||
If message is not given, prompt for it via input()
|
||||
"""
|
||||
if "message" not in context:
|
||||
context["message"] = self.get_human_input(">")
|
||||
return context["message"]
|
||||
|
||||
async def a_generate_init_message(self, **context) -> Union[str, Dict]:
|
||||
"""Generate the initial message for the agent.
|
||||
|
||||
Override this function to customize the initial message based on user's request.
|
||||
If not overridden, "message" needs to be provided in the context.
|
||||
|
||||
Args:
|
||||
**context: any context information, and "message" parameter needs to be provided.
|
||||
If message is not given, prompt for it via input()
|
||||
"""
|
||||
if "message" not in context:
|
||||
context["message"] = await self.a_get_human_input(">")
|
||||
return context["message"]
|
||||
|
||||
def register_function(self, function_map: Dict[str, Callable]):
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import autogen
|
||||
import pytest
|
||||
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
from conftest import skip_openai # noqa: E402
|
||||
@@ -25,20 +27,17 @@ async def test_async_get_human_input():
|
||||
assistant = autogen.AssistantAgent(
|
||||
name="assistant",
|
||||
max_consecutive_auto_reply=2,
|
||||
llm_config={"timeout": 600, "cache_seed": 41, "config_list": config_list, "temperature": 0},
|
||||
llm_config={"seed": 41, "config_list": config_list, "temperature": 0},
|
||||
)
|
||||
|
||||
user_proxy = autogen.UserProxyAgent(name="user", human_input_mode="ALWAYS", code_execution_config=False)
|
||||
|
||||
async def custom_a_get_human_input(prompt):
|
||||
return "This is a test"
|
||||
|
||||
user_proxy.a_get_human_input = custom_a_get_human_input
|
||||
user_proxy.a_get_human_input = AsyncMock(return_value="This is a test")
|
||||
|
||||
user_proxy.register_reply([autogen.Agent, None], autogen.ConversableAgent.a_check_termination_and_human_reply)
|
||||
|
||||
await user_proxy.a_initiate_chat(assistant, clear_history=True, message="Hello.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_async_get_human_input()
|
||||
# Test without message
|
||||
await user_proxy.a_initiate_chat(assistant, clear_history=True)
|
||||
# Assert that custom a_get_human_input was called at least once
|
||||
user_proxy.a_get_human_input.assert_called()
|
||||
|
||||
46
test/agentchat/test_human_input.py
Normal file
46
test/agentchat/test_human_input.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import autogen
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
from conftest import skip_openai # noqa: E402
|
||||
|
||||
try:
|
||||
from openai import OpenAI
|
||||
except ImportError:
|
||||
skip = True
|
||||
else:
|
||||
skip = False or skip_openai
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")
|
||||
def test_get_human_input():
|
||||
config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, KEY_LOC)
|
||||
|
||||
# create an AssistantAgent instance named "assistant"
|
||||
assistant = autogen.AssistantAgent(
|
||||
name="assistant",
|
||||
max_consecutive_auto_reply=2,
|
||||
llm_config={"timeout": 600, "cache_seed": 41, "config_list": config_list, "temperature": 0},
|
||||
)
|
||||
|
||||
user_proxy = autogen.UserProxyAgent(name="user", human_input_mode="ALWAYS", code_execution_config=False)
|
||||
|
||||
# Use MagicMock to create a mock get_human_input function
|
||||
user_proxy.get_human_input = MagicMock(return_value="This is a test")
|
||||
|
||||
user_proxy.register_reply([autogen.Agent, None], autogen.ConversableAgent.a_check_termination_and_human_reply)
|
||||
|
||||
user_proxy.initiate_chat(assistant, clear_history=True, message="Hello.")
|
||||
# Test without supplying messages parameter
|
||||
user_proxy.initiate_chat(assistant, clear_history=True)
|
||||
|
||||
# Assert that custom_a_get_human_input was called at least once
|
||||
user_proxy.get_human_input.assert_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_get_human_input()
|
||||
Reference in New Issue
Block a user