support azure assistant api (#1616)

* support azure assistant api

* try to add azure testing

* improve testing

* fix testing

* fix code

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Ian
2024-02-15 13:29:08 +08:00
committed by GitHub
parent cff9ca9a11
commit b270a2e467
2 changed files with 49 additions and 21 deletions

View File

@@ -53,9 +53,16 @@ class GPTAssistantAgent(ConversableAgent):
- Other kwargs: Except verbose, others are passed directly to ConversableAgent.
"""
# Use AutoGen OpenAIWrapper to create a client
oai_wrapper = OpenAIWrapper(**llm_config)
openai_client_cfg = None
model_name = "gpt-4-1106-preview"
if llm_config and llm_config.get("config_list") is not None and len(llm_config["config_list"]) > 0:
openai_client_cfg = llm_config["config_list"][0].copy()
model_name = openai_client_cfg.pop("model", "gpt-4-1106-preview")
oai_wrapper = OpenAIWrapper(**openai_client_cfg)
if len(oai_wrapper._clients) > 1:
logger.warning("GPT Assistant only supports one OpenAI client. Using the first client in the list.")
self._openai_client = oai_wrapper._clients[0]._oai_client
openai_assistant_id = llm_config.get("assistant_id", None)
if openai_assistant_id is None:
@@ -79,7 +86,7 @@ class GPTAssistantAgent(ConversableAgent):
name=name,
instructions=instructions,
tools=llm_config.get("tools", []),
model=llm_config.get("model", "gpt-4-1106-preview"),
model=model_name,
file_ids=llm_config.get("file_ids", []),
)
else:

View File

@@ -23,9 +23,14 @@ else:
skip = False or skip_openai
if not skip:
config_list = autogen.config_list_from_json(
openai_config_list = autogen.config_list_from_json(
OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict={"api_type": ["openai"]}
)
aoai_config_list = autogen.config_list_from_json(
OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"api_type": ["azure"], "api_version": ["2024-02-15-preview"]},
)
@pytest.mark.skipif(
@@ -33,7 +38,8 @@ if not skip:
reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
)
def test_config_list() -> None:
assert len(config_list) > 0
assert len(openai_config_list) > 0
assert len(aoai_config_list) > 0
@pytest.mark.skipif(
@@ -41,6 +47,11 @@ def test_config_list() -> None:
reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
)
def test_gpt_assistant_chat() -> None:
for gpt_config in [openai_config_list, aoai_config_list]:
_test_gpt_assistant_chat(gpt_config)
def _test_gpt_assistant_chat(gpt_config) -> None:
ossinsight_api_schema = {
"name": "ossinsight_data_api",
"parameters": {
@@ -64,7 +75,7 @@ def test_gpt_assistant_chat() -> None:
name = f"For test_gpt_assistant_chat {uuid.uuid4()}"
analyst = GPTAssistantAgent(
name=name,
llm_config={"tools": [{"type": "function", "function": ossinsight_api_schema}], "config_list": config_list},
llm_config={"tools": [{"type": "function", "function": ossinsight_api_schema}], "config_list": gpt_config},
instructions="Hello, Open Source Project Analyst. You'll conduct comprehensive evaluations of open source projects or organizations on the GitHub platform",
)
try:
@@ -90,7 +101,7 @@ def test_gpt_assistant_chat() -> None:
# check the question asked
ask_ossinsight_mock.assert_called_once()
question_asked = ask_ossinsight_mock.call_args[0][0].lower()
for word in "microsoft autogen stars github".split(" "):
for word in "microsoft autogen star github".split(" "):
assert word in question_asked
# check the answer
@@ -108,6 +119,11 @@ def test_gpt_assistant_chat() -> None:
reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
)
def test_get_assistant_instructions() -> None:
for gpt_config in [openai_config_list, aoai_config_list]:
_test_get_assistant_instructions(gpt_config)
def _test_get_assistant_instructions(gpt_config) -> None:
"""
Test function to create a new GPTAssistantAgent, set its instructions, retrieve the instructions,
and assert that the retrieved instructions match the set instructions.
@@ -117,7 +133,7 @@ def test_get_assistant_instructions() -> None:
name,
instructions="This is a test",
llm_config={
"config_list": config_list,
"config_list": gpt_config,
},
)
@@ -132,6 +148,11 @@ def test_get_assistant_instructions() -> None:
reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
)
def test_gpt_assistant_instructions_overwrite() -> None:
for gpt_config in [openai_config_list, aoai_config_list]:
_test_gpt_assistant_instructions_overwrite(gpt_config)
def _test_gpt_assistant_instructions_overwrite(gpt_config) -> None:
"""
Test that the instructions of a GPTAssistantAgent can be overwritten or not depending on the value of the
`overwrite_instructions` parameter when creating a new assistant with the same ID.
@@ -151,7 +172,7 @@ def test_gpt_assistant_instructions_overwrite() -> None:
name,
instructions=instructions1,
llm_config={
"config_list": config_list,
"config_list": gpt_config,
},
)
@@ -161,7 +182,7 @@ def test_gpt_assistant_instructions_overwrite() -> None:
name,
instructions=instructions2,
llm_config={
"config_list": config_list,
"config_list": gpt_config,
"assistant_id": assistant_id,
},
overwrite_instructions=True,
@@ -191,7 +212,7 @@ def test_gpt_assistant_existing_no_instructions() -> None:
name,
instructions=instructions,
llm_config={
"config_list": config_list,
"config_list": openai_config_list,
},
)
@@ -202,7 +223,7 @@ def test_gpt_assistant_existing_no_instructions() -> None:
assistant = GPTAssistantAgent(
name,
llm_config={
"config_list": config_list,
"config_list": openai_config_list,
"assistant_id": assistant_id,
},
)
@@ -225,7 +246,7 @@ def test_get_assistant_files() -> None:
and assert that the retrieved instructions match the set instructions.
"""
current_file_path = os.path.abspath(__file__)
openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client
openai_client = OpenAIWrapper(config_list=openai_config_list)._clients[0]._oai_client
file = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
name = f"For test_get_assistant_files {uuid.uuid4()}"
@@ -233,7 +254,7 @@ def test_get_assistant_files() -> None:
name,
instructions="This is a test",
llm_config={
"config_list": config_list,
"config_list": openai_config_list,
"tools": [{"type": "retrieval"}],
"file_ids": [file.id],
},
@@ -274,7 +295,7 @@ def test_assistant_retrieval() -> None:
"description": "This is a test function 2",
}
openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client
openai_client = OpenAIWrapper(config_list=openai_config_list)._clients[0]._oai_client
current_file_path = os.path.abspath(__file__)
file_1 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
@@ -289,7 +310,7 @@ def test_assistant_retrieval() -> None:
{"type": "code_interpreter"},
],
"file_ids": [file_1.id, file_2.id],
"config_list": config_list,
"config_list": openai_config_list,
}
name = f"For test_assistant_retrieval {uuid.uuid4()}"
@@ -350,7 +371,7 @@ def test_assistant_mismatch_retrieval() -> None:
"description": "This is a test function 3",
}
openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client
openai_client = OpenAIWrapper(config_list=openai_config_list)._clients[0]._oai_client
current_file_path = os.path.abspath(__file__)
file_1 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
file_2 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
@@ -364,7 +385,7 @@ def test_assistant_mismatch_retrieval() -> None:
{"type": "code_interpreter"},
],
"file_ids": [file_1.id, file_2.id],
"config_list": config_list,
"config_list": openai_config_list,
}
name = f"For test_assistant_retrieval {uuid.uuid4()}"
@@ -400,7 +421,7 @@ def test_assistant_mismatch_retrieval() -> None:
{"type": "function", "function": function_1_schema},
],
"file_ids": [file_2.id],
"config_list": config_list,
"config_list": openai_config_list,
}
assistant_file_ids_mismatch = GPTAssistantAgent(
name,
@@ -418,7 +439,7 @@ def test_assistant_mismatch_retrieval() -> None:
{"type": "function", "function": function_3_schema},
],
"file_ids": [file_2.id, file_1.id],
"config_list": config_list,
"config_list": openai_config_list,
}
assistant_tools_mistaching = GPTAssistantAgent(
name,
@@ -536,7 +557,7 @@ def test_gpt_assistant_tools_overwrite() -> None:
assistant_org = GPTAssistantAgent(
name,
llm_config={
"config_list": config_list,
"config_list": openai_config_list,
"tools": original_tools,
},
)
@@ -548,7 +569,7 @@ def test_gpt_assistant_tools_overwrite() -> None:
assistant = GPTAssistantAgent(
name,
llm_config={
"config_list": config_list,
"config_list": openai_config_list,
"assistant_id": assistant_id,
"tools": new_tools,
},