mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Added the ability to add tags to the OAI_CONFIG_LIST, and filter (#1226)
* Added the ability to add tags to the OAI_CONFIG_LIST, and filter on them. * Update openai_utils.py Co-authored-by: Chi Wang <wang.chi@microsoft.com> --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
@@ -52,7 +52,7 @@ class OpenAIWrapper:
|
||||
"""A wrapper class for openai client."""
|
||||
|
||||
cache_path_root: str = ".cache"
|
||||
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"}
|
||||
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version", "tags"}
|
||||
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
|
||||
total_usage_summary: Optional[Dict[str, Any]] = None
|
||||
actual_usage_summary: Optional[Dict[str, Any]] = None
|
||||
|
||||
@@ -356,6 +356,11 @@ def filter_config(config_list, filter_dict):
|
||||
filter_dict (dict): A dictionary representing the filter criteria, where each key is a
|
||||
field name to check within the configuration dictionaries, and the
|
||||
corresponding value is a list of acceptable values for that field.
|
||||
If the configuration's field's value is not a list, then a match occurs
|
||||
when it is found in the list of acceptable values. If the configuration's
|
||||
field's value is a list, then a match occurs if there is a non-empty
|
||||
intersection with the acceptable values.
|
||||
|
||||
|
||||
Returns:
|
||||
list of dict: A list of configuration dictionaries that meet all the criteria specified
|
||||
@@ -368,6 +373,7 @@ def filter_config(config_list, filter_dict):
|
||||
{'model': 'gpt-3.5-turbo'},
|
||||
{'model': 'gpt-4'},
|
||||
{'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
|
||||
{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
|
||||
]
|
||||
|
||||
# Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
|
||||
@@ -382,6 +388,19 @@ def filter_config(config_list, filter_dict):
|
||||
|
||||
# The resulting `filtered_configs` will be:
|
||||
# [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
|
||||
|
||||
|
||||
# Define a filter to select a given tag
|
||||
filter_criteria = {
|
||||
'tags': ['gpt35_turbo'],
|
||||
}
|
||||
|
||||
# Apply the filter to the configuration list
|
||||
filtered_configs = filter_config(configs, filter_criteria)
|
||||
|
||||
# The resulting `filtered_configs` will be:
|
||||
# [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
|
||||
|
||||
```
|
||||
|
||||
Note:
|
||||
@@ -391,9 +410,18 @@ def filter_config(config_list, filter_dict):
|
||||
- If the list of acceptable values for a key in `filter_dict` includes None, then configuration
|
||||
dictionaries that do not have that key will also be considered a match.
|
||||
"""
|
||||
|
||||
def _satisfies(config_value, acceptable_values):
|
||||
if isinstance(config_value, list):
|
||||
return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
|
||||
else:
|
||||
return config_value in acceptable_values
|
||||
|
||||
if filter_dict:
|
||||
config_list = [
|
||||
config for config in config_list if all(config.get(key) in value for key, value in filter_dict.items())
|
||||
config
|
||||
for config in config_list
|
||||
if all(_satisfies(config.get(key), value) for key, value in filter_dict.items())
|
||||
]
|
||||
return config_list
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
import autogen # noqa: E402
|
||||
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION
|
||||
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, filter_config
|
||||
|
||||
# Example environment variables
|
||||
ENV_VARS = {
|
||||
@@ -48,6 +48,7 @@ JSON_SAMPLE = """
|
||||
},
|
||||
{
|
||||
"model": "gpt-35-turbo-v0301",
|
||||
"tags": ["gpt-3.5-turbo", "gpt35_turbo"],
|
||||
"api_key": "111113fc7e8a46419bfac511bb301111",
|
||||
"base_url": "https://1111.openai.azure.com",
|
||||
"api_type": "azure",
|
||||
@@ -342,5 +343,32 @@ def test_get_config_list():
|
||||
assert len(config_list_with_empty_key) == 2, "The config_list should exclude configurations with empty api_keys."
|
||||
|
||||
|
||||
def test_tags():
|
||||
config_list = json.loads(JSON_SAMPLE)
|
||||
|
||||
target_list = filter_config(config_list, {"model": ["gpt-35-turbo-v0301"]})
|
||||
assert len(target_list) == 1
|
||||
|
||||
list_1 = filter_config(config_list, {"tags": ["gpt35_turbo"]})
|
||||
assert len(list_1) == 1
|
||||
assert list_1[0] == target_list[0]
|
||||
|
||||
list_2 = filter_config(config_list, {"tags": ["gpt-3.5-turbo"]})
|
||||
assert len(list_2) == 1
|
||||
assert list_2[0] == target_list[0]
|
||||
|
||||
list_3 = filter_config(config_list, {"tags": ["gpt-3.5-turbo", "gpt35_turbo"]})
|
||||
assert len(list_3) == 1
|
||||
assert list_3[0] == target_list[0]
|
||||
|
||||
# Will still match because there's a non-empty intersection
|
||||
list_4 = filter_config(config_list, {"tags": ["gpt-3.5-turbo", "does_not_exist"]})
|
||||
assert len(list_4) == 1
|
||||
assert list_4[0] == target_list[0]
|
||||
|
||||
list_5 = filter_config(config_list, {"tags": ["does_not_exist"]})
|
||||
assert len(list_5) == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
|
||||
Reference in New Issue
Block a user