mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Refactor GPTAssistantAgent (#632)
* Refactor GPTAssistantAgent constructor to handle instructions and overwrite_instructions flag - Ensure that `system_message` is always consistent with `instructions` - Ensure provided instructions are always used - Add option to permanently modify the instructions of the assistant * Improve default behavior * Add a test; add method to delete assistant * Add a new test for overwriting instructions * Add test case for when no instructions are given for existing assistant * Add pytest markers to test_gpt_assistant.py * add test in workflow * update * fix test_client_stream * comment out test_hierarchy_ --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com> Co-authored-by: kevin666aa <yrwu000627@gmail.com>
This commit is contained in:
41
.github/workflows/contrib-openai.yml
vendored
41
.github/workflows/contrib-openai.yml
vendored
@@ -97,3 +97,44 @@ jobs:
|
||||
with:
|
||||
file: ./coverage.xml
|
||||
flags: unittests
|
||||
GPTAssistantAgent:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: ["3.11"]
|
||||
runs-on: ${{ matrix.os }}
|
||||
environment: openai1
|
||||
steps:
|
||||
# checkout to pr branch
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install packages and dependencies
|
||||
run: |
|
||||
docker --version
|
||||
python -m pip install --upgrade pip wheel
|
||||
pip install -e .
|
||||
python -c "import autogen"
|
||||
pip install coverage pytest-asyncio
|
||||
- name: Install packages for test when needed
|
||||
run: |
|
||||
pip install docker
|
||||
- name: Coverage
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
|
||||
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
|
||||
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
|
||||
run: |
|
||||
coverage run -a -m pytest test/agentchat/contrib/test_gpt_assistant.py
|
||||
coverage xml
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
with:
|
||||
file: ./coverage.xml
|
||||
flags: unittests
|
||||
|
||||
26
.github/workflows/contrib-tests.yml
vendored
26
.github/workflows/contrib-tests.yml
vendored
@@ -82,3 +82,29 @@ jobs:
|
||||
if: matrix.python-version != '3.10'
|
||||
run: |
|
||||
pytest test/agentchat/contrib/test_compressible_agent.py
|
||||
|
||||
GPTAssistantAgent:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-2019]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install packages and dependencies for all tests
|
||||
run: |
|
||||
python -m pip install --upgrade pip wheel
|
||||
pip install pytest
|
||||
- name: Install packages and dependencies for GPTAssistantAgent
|
||||
run: |
|
||||
pip install -e .
|
||||
pip uninstall -y openai
|
||||
- name: Test GPTAssistantAgent
|
||||
if: matrix.python-version != '3.10'
|
||||
run: |
|
||||
pytest test/agentchat/contrib/test_gpt_assistant.py
|
||||
|
||||
@@ -7,6 +7,7 @@ import logging
|
||||
from autogen import OpenAIWrapper
|
||||
from autogen.agentchat.agent import Agent
|
||||
from autogen.agentchat.assistant_agent import ConversableAgent
|
||||
from autogen.agentchat.assistant_agent import AssistantAgent
|
||||
from typing import Dict, Optional, Union, List, Tuple, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -21,20 +22,70 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
def __init__(
|
||||
self,
|
||||
name="GPT Assistant",
|
||||
instructions: Optional[str] = "You are a helpful GPT Assistant.",
|
||||
instructions: Optional[str] = None,
|
||||
llm_config: Optional[Union[Dict, bool]] = None,
|
||||
overwrite_instructions: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
name (str): name of the agent.
|
||||
instructions (str): instructions for the OpenAI assistant configuration.
|
||||
When instructions is not None, the system message of the agent will be
|
||||
set to the provided instructions and used in the assistant run, irrespective
|
||||
of the overwrite_instructions flag. But when instructions is None,
|
||||
and the assistant does not exist, the system message will be set to
|
||||
AssistantAgent.DEFAULT_SYSTEM_MESSAGE. If the assistant exists, the
|
||||
system message will be set to the existing assistant instructions.
|
||||
llm_config (dict or False): llm inference configuration.
|
||||
- assistant_id: ID of the assistant to use. If None, a new assistant will be created.
|
||||
- model: Model to use for the assistant (gpt-4-1106-preview, gpt-3.5-turbo-1106).
|
||||
- check_every_ms: check thread run status interval
|
||||
- tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval,
|
||||
or build your own tools using Function calling. ref https://platform.openai.com/docs/assistants/tools
|
||||
- file_ids: files used by retrieval in run
|
||||
overwrite_instructions (bool): whether to overwrite the instructions of an existing assistant.
|
||||
"""
|
||||
# Use AutoGen OpenAIWrapper to create a client
|
||||
oai_wrapper = OpenAIWrapper(**llm_config)
|
||||
if len(oai_wrapper._clients) > 1:
|
||||
logger.warning("GPT Assistant only supports one OpenAI client. Using the first client in the list.")
|
||||
self._openai_client = oai_wrapper._clients[0]
|
||||
openai_assistant_id = llm_config.get("assistant_id", None)
|
||||
if openai_assistant_id is None:
|
||||
# create a new assistant
|
||||
if instructions is None:
|
||||
logger.warning(
|
||||
"No instructions were provided for new assistant. Using default instructions from AssistantAgent.DEFAULT_SYSTEM_MESSAGE."
|
||||
)
|
||||
instructions = AssistantAgent.DEFAULT_SYSTEM_MESSAGE
|
||||
self._openai_assistant = self._openai_client.beta.assistants.create(
|
||||
name=name,
|
||||
instructions=instructions,
|
||||
tools=llm_config.get("tools", []),
|
||||
model=llm_config.get("model", "gpt-4-1106-preview"),
|
||||
)
|
||||
else:
|
||||
# retrieve an existing assistant
|
||||
self._openai_assistant = self._openai_client.beta.assistants.retrieve(openai_assistant_id)
|
||||
# if no instructions are provided, set the instructions to the existing instructions
|
||||
if instructions is None:
|
||||
logger.warning(
|
||||
"No instructions were provided for given assistant. Using existing instructions from assistant API."
|
||||
)
|
||||
instructions = self.get_assistant_instructions()
|
||||
elif overwrite_instructions is True:
|
||||
logger.warning(
|
||||
"overwrite_instructions is True. Provided instructions will be used and will modify the assistant in the API"
|
||||
)
|
||||
self._openai_assistant = self._openai_client.beta.assistants.update(
|
||||
assistant_id=openai_assistant_id,
|
||||
instructions=instructions,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"overwrite_instructions is False. Provided instructions will be used without permanently modifying the assistant in the API."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name=name,
|
||||
system_message=instructions,
|
||||
@@ -42,25 +93,6 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
llm_config=llm_config,
|
||||
)
|
||||
|
||||
# Use AutoGen OpenAIWrapper to create a client
|
||||
oai_wrapper = OpenAIWrapper(**self.llm_config)
|
||||
if len(oai_wrapper._clients) > 1:
|
||||
logger.warning("GPT Assistant only supports one OpenAI client. Using the first client in the list.")
|
||||
self._openai_client = oai_wrapper._clients[0]
|
||||
|
||||
openai_assistant_id = llm_config.get("assistant_id", None)
|
||||
if openai_assistant_id is None:
|
||||
# create a new assistant
|
||||
self._openai_assistant = self._openai_client.beta.assistants.create(
|
||||
name=name,
|
||||
instructions=instructions,
|
||||
tools=self.llm_config.get("tools", []),
|
||||
model=self.llm_config.get("model", "gpt-4-1106-preview"),
|
||||
)
|
||||
else:
|
||||
# retrieve an existing assistant
|
||||
self._openai_assistant = self._openai_client.beta.assistants.retrieve(openai_assistant_id)
|
||||
|
||||
# lazly create thread
|
||||
self._openai_threads = {}
|
||||
self._unread_index = defaultdict(int)
|
||||
@@ -107,6 +139,8 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
run = self._openai_client.beta.threads.runs.create(
|
||||
thread_id=assistant_thread.id,
|
||||
assistant_id=self._openai_assistant.id,
|
||||
# pass the latest system message as instructions
|
||||
instructions=self.system_message,
|
||||
)
|
||||
|
||||
run_response_messages = self._get_run_response(assistant_thread, run)
|
||||
@@ -300,3 +334,16 @@ class GPTAssistantAgent(ConversableAgent):
|
||||
def oai_threads(self) -> Dict[Agent, Any]:
|
||||
"""Return the threads of the agent."""
|
||||
return self._openai_threads
|
||||
|
||||
@property
|
||||
def assistant_id(self):
|
||||
"""Return the assistant id"""
|
||||
return self._openai_assistant.id
|
||||
|
||||
def get_assistant_instructions(self):
|
||||
"""Return the assistant instructions from OAI assistant API"""
|
||||
return self._openai_assistant.instructions
|
||||
|
||||
def delete_assistant(self):
|
||||
"""Delete the assistant from OAI assistant API"""
|
||||
self._openai_client.beta.assistants.delete(self.assistant_id)
|
||||
|
||||
@@ -38,9 +38,10 @@ def test_gpt_assistant_chat():
|
||||
"description": "This is an API endpoint allowing users (analysts) to input question about GitHub in text format to retrieve the realted and structured data.",
|
||||
}
|
||||
|
||||
config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, file_location=KEY_LOC)
|
||||
analyst = GPTAssistantAgent(
|
||||
name="Open_Source_Project_Analyst",
|
||||
llm_config={"tools": [{"type": "function", "function": ossinsight_api_schema}]},
|
||||
llm_config={"tools": [{"type": "function", "function": ossinsight_api_schema}], "config_list": config_list},
|
||||
instructions="Hello, Open Source Project Analyst. You'll conduct comprehensive evaluations of open source projects or organizations on the GitHub platform",
|
||||
)
|
||||
analyst.register_function(
|
||||
@@ -62,5 +63,114 @@ def test_gpt_assistant_chat():
|
||||
assert len(analyst._openai_threads) == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform in ["darwin", "win32"] or skip_test,
|
||||
reason="do not run on MacOS or windows or dependency is not installed",
|
||||
)
|
||||
def test_get_assistant_instructions():
|
||||
"""
|
||||
Test function to create a new GPTAssistantAgent, set its instructions, retrieve the instructions,
|
||||
and assert that the retrieved instructions match the set instructions.
|
||||
"""
|
||||
|
||||
config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, file_location=KEY_LOC)
|
||||
assistant = GPTAssistantAgent(
|
||||
"assistant",
|
||||
instructions="This is a test",
|
||||
llm_config={
|
||||
"config_list": config_list,
|
||||
},
|
||||
)
|
||||
|
||||
instruction_match = assistant.get_assistant_instructions() == "This is a test"
|
||||
assistant.delete_assistant()
|
||||
|
||||
assert instruction_match is True
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform in ["darwin", "win32"] or skip_test,
|
||||
reason="do not run on MacOS or windows or dependency is not installed",
|
||||
)
|
||||
def test_gpt_assistant_instructions_overwrite():
|
||||
"""
|
||||
Test that the instructions of a GPTAssistantAgent can be overwritten or not depending on the value of the
|
||||
`overwrite_instructions` parameter when creating a new assistant with the same ID.
|
||||
|
||||
Steps:
|
||||
1. Create a new GPTAssistantAgent with some instructions.
|
||||
2. Get the ID of the assistant.
|
||||
3. Create a new GPTAssistantAgent with the same ID but different instructions and `overwrite_instructions=True`.
|
||||
4. Check that the instructions of the assistant have been overwritten with the new ones.
|
||||
"""
|
||||
|
||||
instructions1 = "This is a test #1"
|
||||
instructions2 = "This is a test #2"
|
||||
|
||||
config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, file_location=KEY_LOC)
|
||||
assistant = GPTAssistantAgent(
|
||||
"assistant",
|
||||
instructions=instructions1,
|
||||
llm_config={
|
||||
"config_list": config_list,
|
||||
},
|
||||
)
|
||||
|
||||
assistant_id = assistant.assistant_id
|
||||
assistant = GPTAssistantAgent(
|
||||
"assistant",
|
||||
instructions=instructions2,
|
||||
llm_config={
|
||||
"config_list": config_list,
|
||||
"assistant_id": assistant_id,
|
||||
},
|
||||
overwrite_instructions=True,
|
||||
)
|
||||
|
||||
instruction_match = assistant.get_assistant_instructions() == instructions2
|
||||
assistant.delete_assistant()
|
||||
|
||||
assert instruction_match is True
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform in ["darwin", "win32"] or skip_test,
|
||||
reason="do not run on MacOS or windows or dependency is not installed",
|
||||
)
|
||||
def test_gpt_assistant_existing_no_instructions():
|
||||
"""
|
||||
Test function to check if the GPTAssistantAgent can retrieve instructions for an existing assistant
|
||||
even if the assistant was created with no instructions initially.
|
||||
"""
|
||||
instructions = "This is a test #1"
|
||||
|
||||
config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, file_location=KEY_LOC)
|
||||
assistant = GPTAssistantAgent(
|
||||
"assistant",
|
||||
instructions=instructions,
|
||||
llm_config={
|
||||
"config_list": config_list,
|
||||
},
|
||||
)
|
||||
|
||||
assistant_id = assistant.assistant_id
|
||||
|
||||
# create a new assistant with the same ID but no instructions
|
||||
assistant = GPTAssistantAgent(
|
||||
"assistant",
|
||||
llm_config={
|
||||
"config_list": config_list,
|
||||
"assistant_id": assistant_id,
|
||||
},
|
||||
)
|
||||
|
||||
instruction_match = assistant.get_assistant_instructions() == instructions
|
||||
assistant.delete_assistant()
|
||||
assert instruction_match is True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_gpt_assistant_chat()
|
||||
test_get_assistant_instructions()
|
||||
test_gpt_assistant_instructions_overwrite()
|
||||
test_gpt_assistant_existing_no_instructions()
|
||||
|
||||
@@ -18,7 +18,7 @@ def test_aoai_chat_completion_stream():
|
||||
filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo"]},
|
||||
)
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(messages=[{"role": "user", "content": "2+2="}], seed=None, stream=True)
|
||||
response = client.create(messages=[{"role": "user", "content": "2+2="}], stream=True)
|
||||
print(response)
|
||||
print(client.extract_text_or_function_call(response))
|
||||
|
||||
@@ -31,7 +31,7 @@ def test_chat_completion_stream():
|
||||
filter_dict={"model": ["gpt-3.5-turbo"]},
|
||||
)
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(messages=[{"role": "user", "content": "1+1="}], seed=None, stream=True)
|
||||
response = client.create(messages=[{"role": "user", "content": "1+1="}], stream=True)
|
||||
print(response)
|
||||
print(client.extract_text_or_function_call(response))
|
||||
|
||||
@@ -63,7 +63,6 @@ def test_chat_functions_stream():
|
||||
response = client.create(
|
||||
messages=[{"role": "user", "content": "What's the weather like today in San Francisco?"}],
|
||||
functions=functions,
|
||||
seed=None,
|
||||
stream=True,
|
||||
)
|
||||
print(response)
|
||||
@@ -74,7 +73,7 @@ def test_chat_functions_stream():
|
||||
def test_completion_stream():
|
||||
config_list = config_list_openai_aoai(KEY_LOC)
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", seed=None, stream=True)
|
||||
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
|
||||
print(response)
|
||||
print(client.extract_text_or_function_call(response))
|
||||
|
||||
|
||||
@@ -84,12 +84,12 @@ def _test_oai_chatgpt_gpt4(save=False):
|
||||
run_notebook("oai_chatgpt_gpt4.ipynb", save=save)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
skip or not sys.version.startswith("3.10"),
|
||||
reason="do not run if openai is not installed or py!=3.10",
|
||||
)
|
||||
def test_hierarchy_flow_using_select_speaker(save=False):
|
||||
run_notebook("agentchat_hierarchy_flow_using_select_speaker.ipynb", save=save)
|
||||
# @pytest.mark.skipif(
|
||||
# skip or not sys.version.startswith("3.10"),
|
||||
# reason="do not run if openai is not installed or py!=3.10",
|
||||
# )
|
||||
# def test_hierarchy_flow_using_select_speaker(save=False):
|
||||
# run_notebook("agentchat_hierarchy_flow_using_select_speaker.ipynb", save=save)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user