mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Restore TeachableAgent tests (#761)
* Update chat_with_teachable_agent.py to v2. * Update agentchat_teachability.ipynb to v2. * Add test of teachability accuracy. * Update installation instructions. * Add to contrib tests. * pre-commit fixes * Apply reviewer suggestions to test workflows.
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
from autogen import UserProxyAgent, config_list_from_json
|
||||
from autogen.agentchat.contrib.teachable_agent import TeachableAgent
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
from test_assistant_agent import OAI_CONFIG_LIST, KEY_LOC # noqa: E402
|
||||
|
||||
|
||||
try:
|
||||
from termcolor import colored
|
||||
@@ -12,10 +18,13 @@ except ImportError:
|
||||
|
||||
verbosity = 0 # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
|
||||
recall_threshold = 1.5 # Higher numbers allow more (but less relevant) memos to be recalled.
|
||||
use_cache = False # If True, cached LLM calls will be skipped and responses pulled from cache. False exposes LLM non-determinism.
|
||||
cache_seed = None # Use an int to seed the response cache. Use None to disable caching.
|
||||
|
||||
# Specify the model to use. GPT-3.5 is less reliable than GPT-4 at learning from user input.
|
||||
# filter_dict = {"model": ["gpt-4-0613"]}
|
||||
# filter_dict = {"model": ["gpt-3.5-turbo-0613"]}
|
||||
filter_dict = {"model": ["gpt-4"]}
|
||||
# filter_dict = {"model": ["gpt-35-turbo-16k", "gpt-3.5-turbo-16k"]}
|
||||
|
||||
|
||||
def create_teachable_agent(reset_db=False):
|
||||
@@ -23,10 +32,10 @@ def create_teachable_agent(reset_db=False):
|
||||
# Load LLM inference endpoints from an env variable or a file
|
||||
# See https://microsoft.github.io/autogen/docs/FAQ#set-your-api-endpoints
|
||||
# and OAI_CONFIG_LIST_sample
|
||||
config_list = config_list_from_json(env_or_file="OAI_CONFIG_LIST", filter_dict=filter_dict)
|
||||
config_list = config_list_from_json(env_or_file=OAI_CONFIG_LIST, filter_dict=filter_dict, file_location=KEY_LOC)
|
||||
teachable_agent = TeachableAgent(
|
||||
name="teachableagent",
|
||||
llm_config={"config_list": config_list, "timeout": 120, "use_cache": use_cache},
|
||||
llm_config={"config_list": config_list, "timeout": 120, "cache_seed": cache_seed},
|
||||
teach_config={
|
||||
"verbosity": verbosity,
|
||||
"reset_db": reset_db,
|
||||
@@ -1,3 +1,11 @@
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
from autogen import ConversableAgent, config_list_from_json
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
from test_assistant_agent import OAI_CONFIG_LIST, KEY_LOC # noqa: E402
|
||||
|
||||
try:
|
||||
from openai import OpenAI
|
||||
from autogen.agentchat.contrib.teachable_agent import TeachableAgent
|
||||
@@ -6,11 +14,6 @@ except ImportError:
|
||||
else:
|
||||
skip = False
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
from autogen import ConversableAgent, config_list_from_json
|
||||
from test_assistant_agent import OAI_CONFIG_LIST, KEY_LOC
|
||||
|
||||
try:
|
||||
from termcolor import colored
|
||||
except ImportError:
|
||||
@@ -25,8 +28,7 @@ skill_verbosity = 3 # 0 for basic info, 1 to add memory operations, 2 for analy
|
||||
|
||||
assert_on_error = False # GPT-4 nearly always succeeds on these unit tests, but GPT-3.5 is a bit less reliable.
|
||||
recall_threshold = 1.5 # Higher numbers allow more (but less relevant) memos to be recalled.
|
||||
cache_seed = None
|
||||
# If int, cached LLM calls will be skipped and responses pulled from cache. None exposes LLM non-determinism.
|
||||
cache_seed = None # Use an int to seed the response cache. Use None to disable caching.
|
||||
|
||||
# Specify the model to use by uncommenting one of the following lines.
|
||||
# filter_dict={"model": ["gpt-4-0613"]}
|
||||
@@ -139,10 +141,10 @@ def use_task_advice_pair_phrasing():
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
skip or not sys.version.startswith("3.11"),
|
||||
reason="do not run if dependency is not installed or py!=3.11",
|
||||
skip,
|
||||
reason="do not run if dependency is not installed",
|
||||
)
|
||||
def test_all():
|
||||
def test_teachability_code_paths():
|
||||
"""Runs this file's unit tests."""
|
||||
total_num_errors, total_num_tests = 0, 0
|
||||
|
||||
@@ -169,6 +171,49 @@ def test_all():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
skip,
|
||||
reason="do not run if dependency is not installed",
|
||||
)
|
||||
def test_teachability_accuracy():
|
||||
"""A very cheap and fast test of teachability accuracy."""
|
||||
print(colored("\nTEST TEACHABILITY ACCURACY", "light_cyan"))
|
||||
|
||||
num_trials = 10 # The expected probability of failure is about 0.3 on each trial.
|
||||
for trial in range(num_trials):
|
||||
teachable_agent = create_teachable_agent(
|
||||
reset_db=True, verbosity=0
|
||||
) # For a clean test, clear the agent's memory.
|
||||
user = ConversableAgent("user", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
|
||||
|
||||
# Prepopulate memory with a few arbitrary memos, just to make retrieval less trivial.
|
||||
teachable_agent.prepopulate_db()
|
||||
|
||||
# Tell the teachable agent something it wouldn't already know.
|
||||
user.initiate_chat(recipient=teachable_agent, message="My favorite color is teal.")
|
||||
|
||||
# Let the teachable agent remember things that should be learned from this chat.
|
||||
teachable_agent.learn_from_user_feedback()
|
||||
|
||||
# Now start a new chat to clear the context, and ask the teachable agent about the new information.
|
||||
print(colored("\nSTARTING A NEW CHAT WITH EMPTY CONTEXT", "light_cyan"))
|
||||
user.initiate_chat(recipient=teachable_agent, message="What's my favorite color?")
|
||||
num_errors = check_agent_response(teachable_agent, user, "teal")
|
||||
|
||||
print(colored(f"\nTRIAL {trial + 1} OF {num_trials} FINISHED", "light_cyan"))
|
||||
|
||||
# Wrap up.
|
||||
teachable_agent.close_db()
|
||||
|
||||
# Exit on the first success.
|
||||
if num_errors == 0:
|
||||
return
|
||||
|
||||
# All trials failed.
|
||||
assert False, "test_teachability_accuracy() failed on all {} trials.".format(num_trials)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Runs this file's unit tests from the command line."""
|
||||
test_all()
|
||||
test_teachability_code_paths()
|
||||
test_teachability_accuracy()
|
||||
Reference in New Issue
Block a user