mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Fix type issues in openai_utils.py (#2062)
* Fix type issues in openai_utils.py * fix incorrect impl * address comment * add to CI
This commit is contained in:
3
.github/workflows/type-check.yml
vendored
3
.github/workflows/type-check.yml
vendored
@@ -23,4 +23,5 @@ jobs:
|
||||
mypy \
|
||||
autogen/logger \
|
||||
autogen/exception_utils.py \
|
||||
autogen/coding
|
||||
autogen/coding \
|
||||
autogen/oai/openai_utils.py
|
||||
|
||||
@@ -8,15 +8,8 @@ from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
|
||||
try:
|
||||
from openai import OpenAI
|
||||
from openai.types.beta.assistant import Assistant
|
||||
|
||||
ERROR = None
|
||||
except ImportError:
|
||||
ERROR = ImportError("Please install openai>=1 to use autogen.OpenAIWrapper.")
|
||||
OpenAI = object
|
||||
Assistant = object
|
||||
from openai import OpenAI
|
||||
from openai.types.beta.assistant import Assistant
|
||||
|
||||
NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]
|
||||
DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"
|
||||
@@ -75,7 +68,7 @@ def get_key(config: Dict[str, Any]) -> str:
|
||||
return json.dumps(config, sort_keys=True)
|
||||
|
||||
|
||||
def is_valid_api_key(api_key: str):
|
||||
def is_valid_api_key(api_key: str) -> bool:
|
||||
"""Determine if input is valid OpenAI API key.
|
||||
|
||||
Args:
|
||||
@@ -89,8 +82,11 @@ def is_valid_api_key(api_key: str):
|
||||
|
||||
|
||||
def get_config_list(
|
||||
api_keys: List, base_urls: Optional[List] = None, api_type: Optional[str] = None, api_version: Optional[str] = None
|
||||
) -> List[Dict]:
|
||||
api_keys: List[str],
|
||||
base_urls: Optional[List[str]] = None,
|
||||
api_type: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get a list of configs for OpenAI API client.
|
||||
|
||||
Args:
|
||||
@@ -143,7 +139,7 @@ def config_list_openai_aoai(
|
||||
openai_api_base_file: Optional[str] = "base_openai.txt",
|
||||
aoai_api_base_file: Optional[str] = "base_aoai.txt",
|
||||
exclude: Optional[str] = None,
|
||||
) -> List[Dict]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get a list of configs for OpenAI API client (including Azure or local model deployments that support OpenAI's chat completion API).
|
||||
|
||||
This function constructs configurations by reading API keys and base URLs from environment variables or text files.
|
||||
@@ -250,8 +246,8 @@ def config_list_openai_aoai(
|
||||
else []
|
||||
)
|
||||
# process openai base urls
|
||||
base_urls = os.environ.get("OPENAI_API_BASE", None)
|
||||
base_urls = base_urls if base_urls is None else base_urls.split("\n")
|
||||
base_urls_env_var = os.environ.get("OPENAI_API_BASE", None)
|
||||
base_urls = base_urls_env_var if base_urls_env_var is None else base_urls_env_var.split("\n")
|
||||
openai_config = (
|
||||
get_config_list(
|
||||
# Assuming OpenAI API_KEY in os.environ["OPENAI_API_KEY"]
|
||||
@@ -271,8 +267,8 @@ def config_list_from_models(
|
||||
aoai_api_key_file: Optional[str] = "key_aoai.txt",
|
||||
aoai_api_base_file: Optional[str] = "base_aoai.txt",
|
||||
exclude: Optional[str] = None,
|
||||
model_list: Optional[list] = None,
|
||||
) -> List[Dict]:
|
||||
model_list: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get a list of configs for API calls with models specified in the model list.
|
||||
|
||||
@@ -338,7 +334,7 @@ def config_list_gpt4_gpt35(
|
||||
aoai_api_key_file: Optional[str] = "key_aoai.txt",
|
||||
aoai_api_base_file: Optional[str] = "base_aoai.txt",
|
||||
exclude: Optional[str] = None,
|
||||
) -> List[Dict]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get a list of configs for 'gpt-4' followed by 'gpt-3.5-turbo' API calls.
|
||||
|
||||
Args:
|
||||
@@ -361,7 +357,10 @@ def config_list_gpt4_gpt35(
|
||||
)
|
||||
|
||||
|
||||
def filter_config(config_list, filter_dict):
|
||||
def filter_config(
|
||||
config_list: List[Dict[str, Any]],
|
||||
filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
This function filters `config_list` by checking each configuration dictionary against the
|
||||
criteria specified in `filter_dict`. A configuration dictionary is retained if for every
|
||||
@@ -426,7 +425,7 @@ def filter_config(config_list, filter_dict):
|
||||
dictionaries that do not have that key will also be considered a match.
|
||||
"""
|
||||
|
||||
def _satisfies(config_value, acceptable_values):
|
||||
def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
|
||||
if isinstance(config_value, list):
|
||||
return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
|
||||
else:
|
||||
@@ -445,7 +444,7 @@ def config_list_from_json(
|
||||
env_or_file: str,
|
||||
file_location: Optional[str] = "",
|
||||
filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]] = None,
|
||||
) -> List[Dict]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieves a list of API configurations from a JSON stored in an environment variable or a file.
|
||||
|
||||
@@ -497,15 +496,22 @@ def config_list_from_json(
|
||||
else:
|
||||
# The environment variable does not exist.
|
||||
# So, `env_or_file` is a filename. We should use the file location.
|
||||
config_list_path = os.path.join(file_location, env_or_file)
|
||||
if file_location is not None:
|
||||
config_list_path = os.path.join(file_location, env_or_file)
|
||||
else:
|
||||
config_list_path = env_or_file
|
||||
|
||||
with open(config_list_path) as json_file:
|
||||
config_list = json.load(json_file)
|
||||
return filter_config(config_list, filter_dict)
|
||||
|
||||
|
||||
def get_config(
|
||||
api_key: str, base_url: Optional[str] = None, api_type: Optional[str] = None, api_version: Optional[str] = None
|
||||
) -> Dict:
|
||||
api_key: Optional[str],
|
||||
base_url: Optional[str] = None,
|
||||
api_type: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Constructs a configuration dictionary for a single model with the provided API configurations.
|
||||
|
||||
@@ -544,7 +550,9 @@ def get_config(
|
||||
|
||||
|
||||
def config_list_from_dotenv(
|
||||
dotenv_file_path: Optional[str] = None, model_api_key_map: Optional[dict] = None, filter_dict: Optional[dict] = None
|
||||
dotenv_file_path: Optional[str] = None,
|
||||
model_api_key_map: Optional[Dict[str, Any]] = None,
|
||||
filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]] = None,
|
||||
) -> List[Dict[str, Union[str, Set[str]]]]:
|
||||
"""
|
||||
Load API configurations from a specified .env file or environment variables and construct a list of configurations.
|
||||
@@ -582,9 +590,10 @@ def config_list_from_dotenv(
|
||||
else:
|
||||
logging.warning(f"The specified .env file {dotenv_path} does not exist.")
|
||||
else:
|
||||
dotenv_path = find_dotenv()
|
||||
if not dotenv_path:
|
||||
dotenv_path_str = find_dotenv()
|
||||
if not dotenv_path_str:
|
||||
logging.warning("No .env file found. Loading configurations from environment variables.")
|
||||
dotenv_path = Path(dotenv_path_str)
|
||||
load_dotenv(dotenv_path)
|
||||
|
||||
# Ensure the model_api_key_map is not None to prevent TypeErrors during key assignment.
|
||||
@@ -647,8 +656,6 @@ def retrieve_assistants_by_name(client: OpenAI, name: str) -> List[Assistant]:
|
||||
"""
|
||||
Return the assistants with the given name from OAI assistant API
|
||||
"""
|
||||
if ERROR:
|
||||
raise ERROR
|
||||
assistants = client.beta.assistants.list()
|
||||
candidate_assistants = []
|
||||
for assistant in assistants.data:
|
||||
|
||||
Reference in New Issue
Block a user