mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
* add client interface, response protocol, and move code into openai client class * add ability to register custom client * tidy up code * adding checks and errors, and more unit tests * remove code * fix error msg * add use_docer False in notebook * better error message * add another example to custom model notebook * rename and have register_client take model name too * make Client protocol and remove inheritance * renames * update notebook * add link * rename and more error checking for registered agents * adding message retrieval to client protocol for more flexible response * fix failing openai test * api_type req made model_client_cls requirement * notebook cleanup and added blog * remove raise error if client list is empty - client list will never be empty it will have placeholders * rename Client -> ModelClient * add forgotten file * fix test by fetching internal client * Update autogen/oai/client.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * don't add retrieval function to cache * added placeholder cllient class during initial client init, and rewrote registration * fix spelling * Update autogen/agentchat/conversable_agent.py Co-authored-by: Chi Wang <wang.chi@microsoft.com> * type hints, small fixes, docstr comment * fix api type checking --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> Co-authored-by: Chi Wang <wang.chi@microsoft.com>
167 lines
4.9 KiB
Python
167 lines
4.9 KiB
Python
import pytest
|
|
from autogen import OpenAIWrapper
|
|
from autogen.oai import ModelClient
|
|
from typing import Dict
|
|
|
|
# Probe for the openai>=1 package; every test below is skipped when it is
# absent (see the @pytest.mark.skipif decorators).
skip = False
try:
    from openai import OpenAI  # noqa: F401 - imported only to detect availability
except ImportError:
    skip = True
|
|
|
|
|
|
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_custom_model_client():
    """End-to-end check of the custom model-client path: a user-defined client
    class is registered on OpenAIWrapper, and ``create()`` must route through
    it, returning its response and cost, with the config values it received
    recorded into ``test_hook`` for verification."""
    TEST_COST = 20000000
    TEST_CUSTOM_RESPONSE = "This is a custom response."
    TEST_DEVICE = "cpu"
    TEST_LOCAL_MODEL_NAME = "local_model_name"
    TEST_OTHER_PARAMS_VAL = "other_params"
    TEST_MAX_LENGTH = 1000

    class CustomModel:
        """Minimal client following the ModelClient protocol; mirrors the
        config it was constructed with into ``test_hook``."""

        def __init__(self, config: Dict, test_hook):
            self.test_hook = test_hook
            self.device = config["device"]
            self.model = config["model"]
            self.other_params = config["params"]["other_params"]
            self.max_length = config["params"]["max_length"]
            # Record every received parameter so the assertions below can
            # confirm the wrapper forwarded the config unchanged.
            self.test_hook.update(
                called=True,
                device=self.device,
                model=self.model,
                other_params=self.other_params,
                max_length=self.max_length,
            )

        def create(self, params):
            from types import SimpleNamespace

            # Response shape must follow Client.ClientResponseProtocol.
            message = SimpleNamespace(content=TEST_CUSTOM_RESPONSE)
            single_choice = SimpleNamespace(message=message)
            return SimpleNamespace(choices=[single_choice], model=self.model)

        def message_retrieval(self, response):
            first = response.choices[0]
            return [first.message.content]

        def cost(self, response) -> float:
            """Calculate the cost of the response."""
            response.cost = TEST_COST
            return TEST_COST

        @staticmethod
        def get_usage(response) -> Dict:
            return {}

    config_list = [
        {
            "model": TEST_LOCAL_MODEL_NAME,
            "model_client_cls": "CustomModel",
            "device": TEST_DEVICE,
            "params": {
                "max_length": TEST_MAX_LENGTH,
                "other_params": TEST_OTHER_PARAMS_VAL,
            },
        }
    ]

    test_hook = {"called": False}

    client = OpenAIWrapper(config_list=config_list)
    client.register_model_client(model_client_cls=CustomModel, test_hook=test_hook)

    response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
    assert response.choices[0].message.content == TEST_CUSTOM_RESPONSE
    assert response.cost == TEST_COST

    # The custom client must have been instantiated with exactly the config above.
    assert test_hook["called"]
    assert test_hook["device"] == TEST_DEVICE
    assert test_hook["model"] == TEST_LOCAL_MODEL_NAME
    assert test_hook["other_params"] == TEST_OTHER_PARAMS_VAL
    assert test_hook["max_length"] == TEST_MAX_LENGTH
|
|
|
|
|
|
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_registering_with_wrong_class_name_raises_error():
    """``register_model_client`` must raise ValueError when the class being
    registered does not match any ``model_client_cls`` name in the config."""

    class CustomModel:
        """Protocol-shaped stub; never actually invoked in this test."""

        def __init__(self, config: Dict):
            pass

        def create(self, params):
            return None

        def message_retrieval(self, response):
            return []

        def cost(self, response) -> float:
            return 0

        @staticmethod
        def get_usage(response) -> Dict:
            return {}

    config_list = [
        {
            "model": "local_model_name",
            "model_client_cls": "CustomModelWrongName",
        }
    ]
    client = OpenAIWrapper(config_list=config_list)

    # "CustomModel" does not match "CustomModelWrongName", so registration fails.
    with pytest.raises(ValueError):
        client.register_model_client(model_client_cls=CustomModel)
|
|
|
|
|
|
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_not_all_clients_registered_raises_error():
    """``create()`` must raise RuntimeError while some config entry that
    declares a ``model_client_cls`` still has no registered client: two
    entries below, only one ``register_model_client`` call."""

    class CustomModel:
        """Protocol-shaped stub; never actually invoked in this test."""

        def __init__(self, config: Dict):
            pass

        def create(self, params):
            return None

        def message_retrieval(self, response):
            return []

        def cost(self, response) -> float:
            return 0

        @staticmethod
        def get_usage(response) -> Dict:
            return {}

    # Two identical placeholder entries, differing only in model name; both
    # require the custom client class.
    config_list = [
        {
            "model": model_name,
            "model_client_cls": "CustomModel",
            "device": "cpu",
            "params": {
                "max_length": 1000,
                "other_params": "other_params",
            },
        }
        for model_name in ("local_model_name", "local_model_name_2")
    ]

    client = OpenAIWrapper(config_list=config_list)

    # A single registration for two placeholder entries - one remains
    # unregistered, which create() must detect.
    client.register_model_client(model_client_cls=CustomModel)

    with pytest.raises(RuntimeError):
        client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
|