Fix type issues in openai_utils.py (#2062)

* Fix type issues in openai_utils.py

* fix incorrect impl

* address comment

* add to CI
This commit is contained in:
Jack Gerrits
2024-03-19 15:35:19 -04:00
committed by GitHub
parent e23bdfb38e
commit 38b64b6ade
2 changed files with 38 additions and 30 deletions

View File

@@ -23,4 +23,5 @@ jobs:
mypy \
autogen/logger \
autogen/exception_utils.py \
autogen/coding
autogen/coding \
autogen/oai/openai_utils.py

View File

@@ -8,15 +8,8 @@ from typing import Any, Dict, List, Optional, Set, Union
from dotenv import find_dotenv, load_dotenv
try:
from openai import OpenAI
from openai.types.beta.assistant import Assistant
ERROR = None
except ImportError:
ERROR = ImportError("Please install openai>=1 to use autogen.OpenAIWrapper.")
OpenAI = object
Assistant = object
from openai import OpenAI
from openai.types.beta.assistant import Assistant
NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]
DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"
@@ -75,7 +68,7 @@ def get_key(config: Dict[str, Any]) -> str:
return json.dumps(config, sort_keys=True)
def is_valid_api_key(api_key: str):
def is_valid_api_key(api_key: str) -> bool:
"""Determine if input is valid OpenAI API key.
Args:
@@ -89,8 +82,11 @@ def is_valid_api_key(api_key: str):
def get_config_list(
api_keys: List, base_urls: Optional[List] = None, api_type: Optional[str] = None, api_version: Optional[str] = None
) -> List[Dict]:
api_keys: List[str],
base_urls: Optional[List[str]] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Get a list of configs for OpenAI API client.
Args:
@@ -143,7 +139,7 @@ def config_list_openai_aoai(
openai_api_base_file: Optional[str] = "base_openai.txt",
aoai_api_base_file: Optional[str] = "base_aoai.txt",
exclude: Optional[str] = None,
) -> List[Dict]:
) -> List[Dict[str, Any]]:
"""Get a list of configs for OpenAI API client (including Azure or local model deployments that support OpenAI's chat completion API).
This function constructs configurations by reading API keys and base URLs from environment variables or text files.
@@ -250,8 +246,8 @@ def config_list_openai_aoai(
else []
)
# process openai base urls
base_urls = os.environ.get("OPENAI_API_BASE", None)
base_urls = base_urls if base_urls is None else base_urls.split("\n")
base_urls_env_var = os.environ.get("OPENAI_API_BASE", None)
base_urls = base_urls_env_var if base_urls_env_var is None else base_urls_env_var.split("\n")
openai_config = (
get_config_list(
# Assuming OpenAI API_KEY in os.environ["OPENAI_API_KEY"]
@@ -271,8 +267,8 @@ def config_list_from_models(
aoai_api_key_file: Optional[str] = "key_aoai.txt",
aoai_api_base_file: Optional[str] = "base_aoai.txt",
exclude: Optional[str] = None,
model_list: Optional[list] = None,
) -> List[Dict]:
model_list: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""
Get a list of configs for API calls with models specified in the model list.
@@ -338,7 +334,7 @@ def config_list_gpt4_gpt35(
aoai_api_key_file: Optional[str] = "key_aoai.txt",
aoai_api_base_file: Optional[str] = "base_aoai.txt",
exclude: Optional[str] = None,
) -> List[Dict]:
) -> List[Dict[str, Any]]:
"""Get a list of configs for 'gpt-4' followed by 'gpt-3.5-turbo' API calls.
Args:
@@ -361,7 +357,10 @@ def config_list_gpt4_gpt35(
)
def filter_config(config_list, filter_dict):
def filter_config(
config_list: List[Dict[str, Any]],
filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]],
) -> List[Dict[str, Any]]:
"""
This function filters `config_list` by checking each configuration dictionary against the
criteria specified in `filter_dict`. A configuration dictionary is retained if for every
@@ -426,7 +425,7 @@ def filter_config(config_list, filter_dict):
dictionaries that do not have that key will also be considered a match.
"""
def _satisfies(config_value, acceptable_values):
def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
if isinstance(config_value, list):
return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
else:
@@ -445,7 +444,7 @@ def config_list_from_json(
env_or_file: str,
file_location: Optional[str] = "",
filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]] = None,
) -> List[Dict]:
) -> List[Dict[str, Any]]:
"""
Retrieves a list of API configurations from a JSON stored in an environment variable or a file.
@@ -497,15 +496,22 @@ def config_list_from_json(
else:
# The environment variable does not exist.
# So, `env_or_file` is a filename. We should use the file location.
config_list_path = os.path.join(file_location, env_or_file)
if file_location is not None:
config_list_path = os.path.join(file_location, env_or_file)
else:
config_list_path = env_or_file
with open(config_list_path) as json_file:
config_list = json.load(json_file)
return filter_config(config_list, filter_dict)
def get_config(
api_key: str, base_url: Optional[str] = None, api_type: Optional[str] = None, api_version: Optional[str] = None
) -> Dict:
api_key: Optional[str],
base_url: Optional[str] = None,
api_type: Optional[str] = None,
api_version: Optional[str] = None,
) -> Dict[str, Any]:
"""
Constructs a configuration dictionary for a single model with the provided API configurations.
@@ -544,7 +550,9 @@ def get_config(
def config_list_from_dotenv(
dotenv_file_path: Optional[str] = None, model_api_key_map: Optional[dict] = None, filter_dict: Optional[dict] = None
dotenv_file_path: Optional[str] = None,
model_api_key_map: Optional[Dict[str, Any]] = None,
filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]] = None,
) -> List[Dict[str, Union[str, Set[str]]]]:
"""
Load API configurations from a specified .env file or environment variables and construct a list of configurations.
@@ -582,9 +590,10 @@ def config_list_from_dotenv(
else:
logging.warning(f"The specified .env file {dotenv_path} does not exist.")
else:
dotenv_path = find_dotenv()
if not dotenv_path:
dotenv_path_str = find_dotenv()
if not dotenv_path_str:
logging.warning("No .env file found. Loading configurations from environment variables.")
dotenv_path = Path(dotenv_path_str)
load_dotenv(dotenv_path)
# Ensure the model_api_key_map is not None to prevent TypeErrors during key assignment.
@@ -647,8 +656,6 @@ def retrieve_assistants_by_name(client: OpenAI, name: str) -> List[Assistant]:
"""
Return the assistants with the given name from OAI assistant API
"""
if ERROR:
raise ERROR
assistants = client.beta.assistants.list()
candidate_assistants = []
for assistant in assistants.data: