Added the ability to add tags to the OAI_CONFIG_LIST, and filter (#1226)

* Added the ability to add tags to the OAI_CONFIG_LIST, and filter on them.

* Update openai_utils.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
afourney
2024-01-14 20:47:19 -08:00
committed by GitHub
parent 63a35e79f8
commit e6325a402a
3 changed files with 59 additions and 3 deletions

View File

@@ -52,7 +52,7 @@ class OpenAIWrapper:
"""A wrapper class for openai client."""
cache_path_root: str = ".cache"
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"}
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version", "tags"}
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
total_usage_summary: Optional[Dict[str, Any]] = None
actual_usage_summary: Optional[Dict[str, Any]] = None

View File

@@ -356,6 +356,11 @@ def filter_config(config_list, filter_dict):
filter_dict (dict): A dictionary representing the filter criteria, where each key is a
field name to check within the configuration dictionaries, and the
corresponding value is a list of acceptable values for that field.
If the configuration's field's value is not a list, then a match occurs
when it is found in the list of acceptable values. If the configuration's
field's value is a list, then a match occurs if there is a non-empty
intersection with the acceptable values.
Returns:
list of dict: A list of configuration dictionaries that meet all the criteria specified
@@ -368,6 +373,7 @@ def filter_config(config_list, filter_dict):
{'model': 'gpt-3.5-turbo'},
{'model': 'gpt-4'},
{'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
]
# Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
@@ -382,6 +388,19 @@ def filter_config(config_list, filter_dict):
# The resulting `filtered_configs` will be:
# [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
# Define a filter to select a given tag
filter_criteria = {
'tags': ['gpt35_turbo'],
}
# Apply the filter to the configuration list
filtered_configs = filter_config(configs, filter_criteria)
# The resulting `filtered_configs` will be:
# [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
```
Note:
@@ -391,9 +410,18 @@ def filter_config(config_list, filter_dict):
- If the list of acceptable values for a key in `filter_dict` includes None, then configuration
dictionaries that do not have that key will also be considered a match.
"""
def _satisfies(config_value, acceptable_values):
if isinstance(config_value, list):
return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
else:
return config_value in acceptable_values
if filter_dict:
config_list = [
config for config in config_list if all(config.get(key) in value for key, value in filter_dict.items())
config
for config in config_list
if all(_satisfies(config.get(key), value) for key, value in filter_dict.items())
]
return config_list

View File

@@ -8,7 +8,7 @@ from unittest.mock import patch
import pytest
import autogen # noqa: E402
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, filter_config
# Example environment variables
ENV_VARS = {
@@ -48,6 +48,7 @@ JSON_SAMPLE = """
},
{
"model": "gpt-35-turbo-v0301",
"tags": ["gpt-3.5-turbo", "gpt35_turbo"],
"api_key": "111113fc7e8a46419bfac511bb301111",
"base_url": "https://1111.openai.azure.com",
"api_type": "azure",
@@ -342,5 +343,32 @@ def test_get_config_list():
assert len(config_list_with_empty_key) == 2, "The config_list should exclude configurations with empty api_keys."
def test_tags():
config_list = json.loads(JSON_SAMPLE)
target_list = filter_config(config_list, {"model": ["gpt-35-turbo-v0301"]})
assert len(target_list) == 1
list_1 = filter_config(config_list, {"tags": ["gpt35_turbo"]})
assert len(list_1) == 1
assert list_1[0] == target_list[0]
list_2 = filter_config(config_list, {"tags": ["gpt-3.5-turbo"]})
assert len(list_2) == 1
assert list_2[0] == target_list[0]
list_3 = filter_config(config_list, {"tags": ["gpt-3.5-turbo", "gpt35_turbo"]})
assert len(list_3) == 1
assert list_3[0] == target_list[0]
# Will still match because there's a non-empty intersection
list_4 = filter_config(config_list, {"tags": ["gpt-3.5-turbo", "does_not_exist"]})
assert len(list_4) == 1
assert list_4[0] == target_list[0]
list_5 = filter_config(config_list, {"tags": ["does_not_exist"]})
assert len(list_5) == 0
if __name__ == "__main__":
pytest.main()