mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 15:17:59 -05:00
fix(agent): Unbreak LLM status check on start-up

Fixes #7508

- Amend `app/configurator.py:check_model(..)` to check multiple models at once and save duplicate API calls
- Amend `MultiProvider.get_available_providers()` to verify availability by fetching models and handle failure
This commit is contained in:
@@ -51,8 +51,9 @@ async def apply_overrides_to_config(
|
||||
raise click.UsageError("--continuous-limit can only be used with --continuous")
|
||||
|
||||
# Check availability of configured LLMs; fallback to other LLM if unavailable
|
||||
config.fast_llm = await check_model(config.fast_llm, "fast_llm")
|
||||
config.smart_llm = await check_model(config.smart_llm, "smart_llm")
|
||||
config.fast_llm, config.smart_llm = await check_models(
|
||||
(config.fast_llm, "fast_llm"), (config.smart_llm, "smart_llm")
|
||||
)
|
||||
|
||||
if skip_reprompt:
|
||||
config.skip_reprompt = True
|
||||
@@ -61,17 +62,22 @@ async def apply_overrides_to_config(
|
||||
config.skip_news = True
|
||||
|
||||
|
||||
async def check_models(
    *models: tuple[ModelName, Literal["smart_llm", "fast_llm"]]
) -> tuple[ModelName, ...]:
    """Check availability of the given LLMs in a single pass.

    Fetches the list of available chat models once (one batch of provider
    calls instead of one per configured model) and validates every requested
    model against it.

    Args:
        *models: ``(model_name, config_field)`` pairs, where ``config_field``
            is ``"smart_llm"`` or ``"fast_llm"`` and is only used for the
            warning message.

    Returns:
        A tuple with one entry per input pair: the requested model if it is
        available, otherwise the ``GPT_3_MODEL`` fallback.
    """
    multi_provider = MultiProvider()
    available_models = await multi_provider.get_available_chat_models()
    # Set lookup avoids re-scanning the full model list for every
    # requested model.
    available_model_names = {m.name for m in available_models}

    checked_models: list[ModelName] = []
    for model, model_type in models:
        if model in available_model_names:
            checked_models.append(model)
        else:
            logger.warning(
                f"You don't have access to {model}. "
                f"Setting {model_type} to {GPT_3_MODEL}."
            )
            checked_models.append(GPT_3_MODEL)

    return tuple(checked_models)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Iterator, Optional, Sequence, TypeVar, get_args
|
||||
from typing import Any, AsyncIterator, Callable, Optional, Sequence, TypeVar, get_args
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
@@ -68,7 +68,7 @@ class MultiProvider(BaseChatModelProvider[ModelName, ModelProviderSettings]):
|
||||
|
||||
async def get_available_chat_models(self) -> Sequence[ChatModelInfo[ModelName]]:
    """Collect the chat models offered by every reachable provider."""
    available: list[ChatModelInfo[ModelName]] = []
    async for provider in self.get_available_providers():
        available += await provider.get_available_chat_models()
    return available
|
||||
|
||||
@@ -120,14 +120,18 @@ class MultiProvider(BaseChatModelProvider[ModelName, ModelProviderSettings]):
|
||||
model_info = CHAT_MODELS[model]
|
||||
return self._get_provider(model_info.provider_name)
|
||||
|
||||
async def get_available_providers(self) -> AsyncIterator[ChatModelProvider]:
    """Yield every model provider that is configured and actually reachable.

    Availability is verified by fetching the provider's model list, so a
    provider that is misconfigured or unreachable is skipped here instead
    of surfacing an error later at chat time.
    """
    for provider_name in ModelProviderName:
        self._logger.debug(f"Checking if provider {provider_name} is available...")
        try:
            provider = self._get_provider(provider_name)
            # Verify the connection/credentials work before advertising
            # the provider as available.
            await provider.get_available_models()  # check connection
            yield provider
            self._logger.debug(f"Provider '{provider_name}' is available!")
        except ValueError:
            # NOTE(review): presumably raised by _get_provider for an
            # unconfigured provider (e.g. missing credentials) — confirm;
            # skipped silently.
            pass
        except Exception as e:
            # Any other failure (network, auth, ...) just makes the provider
            # unavailable; logged at debug level only, then move on.
            self._logger.debug(f"Provider '{provider_name}' is failing: {e}")
|
||||
def _get_provider(self, provider_name: ModelProviderName) -> ChatModelProvider:
|
||||
_provider = self._provider_instances.get(provider_name)
|
||||
|
||||
Reference in New Issue
Block a user