Files
autogen/test/oai/test_custom_client.py
olgavrou 00417edb5a Custom Model Client support (#1345)
* add client interface, response protocol, and move code into openai client class

* add ability to register custom client

* tidy up code

* adding checks and errors, and more unit tests

* remove code

* fix error msg

* add use_docker False in notebook

* better error message

* add another example to custom model notebook

* rename and have register_client take model name too

* make Client protocol and remove inheritance

* renames

* update notebook

* add link

* rename and more error checking for registered agents

* adding message retrieval to client protocol for more flexible response

* fix failing openai test

* api_type req made model_client_cls requirement

* notebook cleanup and added blog

* remove raise error if client list is empty - client list will never be empty it will have placeholders

* rename Client -> ModelClient

* add forgotten file

* fix test by fetching internal client

* Update autogen/oai/client.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* don't add retrieval function to cache

* added placeholder client class during initial client init, and rewrote registration

* fix spelling

* Update autogen/agentchat/conversable_agent.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* type hints, small fixes, docstr comment

* fix api type checking

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
2024-02-02 17:21:19 +00:00

167 lines
4.9 KiB
Python

import pytest
from autogen import OpenAIWrapper
from autogen.oai import ModelClient
from typing import Dict
try:
    # Probe for the openai>=1 client class; it appears to be imported only to
    # decide whether the tests below should be skipped.
    from openai import OpenAI  # noqa: F401
    skip = False
except ImportError:
    skip = True
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_custom_model_client():
    """Register a custom model client and check OpenAIWrapper routes requests through it.

    The single config entry names ``CustomModel`` via ``model_client_cls``.
    After ``register_model_client`` is called, the test checks four things:
    the class was constructed with that entry's values, it received the extra
    ``test_hook`` keyword argument, the wrapper's response carries the custom
    content, and the response carries the cost reported by the custom client.
    """
    TEST_COST = 20000000
    TEST_CUSTOM_RESPONSE = "This is a custom response."
    TEST_DEVICE = "cpu"
    TEST_LOCAL_MODEL_NAME = "local_model_name"
    TEST_OTHER_PARAMS_VAL = "other_params"
    TEST_MAX_LENGTH = 1000

    class CustomModel:
        """Minimal model client that records how it was configured into ``test_hook``."""

        def __init__(self, config: Dict, test_hook):
            # `config` is expected to be the config_list entry naming this class;
            # `test_hook` is the extra kwarg passed to register_model_client below.
            self.test_hook = test_hook
            self.device = config["device"]
            self.model = config["model"]
            self.other_params = config["params"]["other_params"]
            self.max_length = config["params"]["max_length"]
            self.test_hook["called"] = True
            # Mirror every parsed value into the hook so the outer test can assert on it.
            self.test_hook["device"] = self.device
            self.test_hook["model"] = self.model
            self.test_hook["other_params"] = self.other_params
            self.test_hook["max_length"] = self.max_length

        def create(self, params):
            """Return a canned response, ignoring ``params``."""
            from types import SimpleNamespace

            # The response must follow ModelClient.ModelClientResponseProtocol:
            # a `choices` list whose items expose `.message.content`, plus `model`.
            response = SimpleNamespace()
            response.choices = []
            choice = SimpleNamespace()
            choice.message = SimpleNamespace()
            choice.message.content = TEST_CUSTOM_RESPONSE
            response.choices.append(choice)
            response.model = self.model
            return response

        def message_retrieval(self, response):
            """Return the text of the first choice as a one-element list."""
            return [response.choices[0].message.content]

        def cost(self, response) -> float:
            """Calculate the cost of the response (a fixed value, also stored on it)."""
            response.cost = TEST_COST
            return TEST_COST

        @staticmethod
        def get_usage(response) -> Dict:
            """Report no usage statistics."""
            return {}

    config_list = [
        {
            "model": TEST_LOCAL_MODEL_NAME,
            "model_client_cls": "CustomModel",
            "device": TEST_DEVICE,
            "params": {
                "max_length": TEST_MAX_LENGTH,
                "other_params": TEST_OTHER_PARAMS_VAL,
            },
        },
    ]

    test_hook = {"called": False}
    client = OpenAIWrapper(config_list=config_list)
    # Extra kwargs (test_hook) are forwarded to the CustomModel constructor.
    client.register_model_client(model_client_cls=CustomModel, test_hook=test_hook)
    # cache_seed=None so the response is produced by the client, not a cache hit.
    response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)

    assert response.choices[0].message.content == TEST_CUSTOM_RESPONSE
    assert response.cost == TEST_COST

    assert test_hook["called"]
    assert test_hook["device"] == TEST_DEVICE
    assert test_hook["model"] == TEST_LOCAL_MODEL_NAME
    assert test_hook["other_params"] == TEST_OTHER_PARAMS_VAL
    assert test_hook["max_length"] == TEST_MAX_LENGTH
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_registering_with_wrong_class_name_raises_error():
    """Registering a class that no config entry names should raise ``ValueError``.

    The only config entry asks for ``"CustomModelWrongName"``. Handing
    ``register_model_client`` a class whose name is ``CustomModel`` therefore
    has no entry to match.
    """

    class CustomModel:
        # Stub exposing the model-client method set; return values are placeholders.
        def __init__(self, config: Dict):
            pass

        def create(self, params):
            return None

        def message_retrieval(self, response):
            return []

        def cost(self, response) -> float:
            return 0

        @staticmethod
        def get_usage(response) -> Dict:
            return {}

    mismatched_entry = {"model": "local_model_name", "model_client_cls": "CustomModelWrongName"}
    wrapper = OpenAIWrapper(config_list=[mismatched_entry])

    with pytest.raises(ValueError):
        wrapper.register_model_client(model_client_cls=CustomModel)
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_not_all_clients_registered_raises_error():
    """``create`` should raise ``RuntimeError`` while a custom-client entry is still unregistered.

    Two config entries both name ``CustomModel``, but ``register_model_client``
    is called only once. The test expects that single call to leave at least
    one entry without a client.
    """

    class CustomModel:
        # Stub exposing the model-client method set; return values are placeholders.
        def __init__(self, config: Dict):
            pass

        def create(self, params):
            return None

        def message_retrieval(self, response):
            return []

        def cost(self, response) -> float:
            return 0

        @staticmethod
        def get_usage(response) -> Dict:
            return {}

    # Two otherwise-identical entries that differ only in their model name.
    config_list = [
        {
            "model": model_name,
            "model_client_cls": "CustomModel",
            "device": "cpu",
            "params": {
                "max_length": 1000,
                "other_params": "other_params",
            },
        }
        for model_name in ("local_model_name", "local_model_name_2")
    ]

    wrapper = OpenAIWrapper(config_list=config_list)
    wrapper.register_model_client(model_client_cls=CustomModel)

    with pytest.raises(RuntimeError):
        wrapper.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)