mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Add fallback logic for disabled LLM models
Introduces fallback selection for disabled LLM models in llm_call, preferring enabled models from the same provider. Updates registry utilities to support fallback lookup, model info retrieval, and validation of all known model slugs. Schema utilities now keep all known models in validation enums while showing only enabled models in UI options.
This commit is contained in:
@@ -337,93 +337,88 @@ async def llm_call(
|
||||
- prompt_tokens: The number of tokens used in the prompt.
|
||||
- completion_tokens: The number of tokens used in the completion.
|
||||
"""
|
||||
# Get model metadata - try cache first, then fallback to async lookup
|
||||
# Also check if the model is enabled
|
||||
try:
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
# Get model metadata and check if enabled - with fallback support
|
||||
# The model we'll actually use (may differ if original is disabled)
|
||||
model_to_use = llm_model.value
|
||||
|
||||
# Check if model is enabled - get from registry
|
||||
from backend.data.llm_registry import _dynamic_models
|
||||
# Check if model is in registry and if it's enabled
|
||||
from backend.data.llm_registry import (
|
||||
get_fallback_model_for_disabled,
|
||||
get_model_info,
|
||||
)
|
||||
|
||||
if llm_model.value in _dynamic_models:
|
||||
model_info = _dynamic_models[llm_model.value]
|
||||
if not model_info.is_enabled:
|
||||
raise ValueError(f"LLM model '{llm_model.value}' is disabled.")
|
||||
except ValueError as e:
|
||||
# Re-raise if it's our disabled model error
|
||||
if "is disabled" in str(e):
|
||||
raise
|
||||
# Model not in cache - try refreshing the registry once if we have DB access
|
||||
import logging
|
||||
model_info = get_model_info(llm_model.value)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
"Model %s not found in registry cache",
|
||||
llm_model.value,
|
||||
)
|
||||
|
||||
# Try refreshing the registry if we have database access
|
||||
from backend.data.db import is_connected
|
||||
|
||||
if is_connected():
|
||||
try:
|
||||
logger.info(
|
||||
"Refreshing LLM registry and retrying lookup for %s",
|
||||
llm_model.value,
|
||||
)
|
||||
await llm_registry.refresh_llm_registry()
|
||||
# Try again after refresh
|
||||
try:
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
|
||||
# Check if model is enabled after refresh
|
||||
from backend.data.llm_registry import _dynamic_models
|
||||
|
||||
if llm_model.value in _dynamic_models:
|
||||
model_info = _dynamic_models[llm_model.value]
|
||||
if not model_info.is_enabled:
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' is disabled. "
|
||||
"Please enable it in the LLM registry via the admin UI to use this model."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Successfully loaded model %s metadata after registry refresh",
|
||||
llm_model.value,
|
||||
)
|
||||
except ValueError as ve:
|
||||
# Re-raise if it's our disabled model error
|
||||
if "is disabled" in str(ve):
|
||||
raise
|
||||
# Still not found after refresh
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry after refresh. "
|
||||
"Please ensure the model is added and enabled in the LLM registry via the admin UI."
|
||||
)
|
||||
except Exception as refresh_exc:
|
||||
logger.error(
|
||||
"Failed to refresh LLM registry: %s", refresh_exc, exc_info=True
|
||||
)
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
|
||||
"Please ensure the model is added to the LLM registry via the admin UI."
|
||||
) from refresh_exc
|
||||
if model_info and not model_info.is_enabled:
|
||||
# Model is disabled - try to find a fallback from the same provider
|
||||
fallback = get_fallback_model_for_disabled(llm_model.value)
|
||||
if fallback:
|
||||
logger.warning(
|
||||
f"Model '{llm_model.value}' is disabled. Using fallback model '{fallback.slug}' from the same provider ({fallback.metadata.provider})."
|
||||
)
|
||||
model_to_use = fallback.slug
|
||||
# Use fallback model's metadata
|
||||
provider = fallback.metadata.provider
|
||||
context_window = fallback.metadata.context_window
|
||||
model_max_output = fallback.metadata.max_output_tokens or int(2**15)
|
||||
else:
|
||||
# No DB access (e.g., in executor without direct DB connection)
|
||||
# The registry should have been loaded on startup
|
||||
# No fallback available - raise error
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry cache. "
|
||||
"The registry may need to be refreshed. Please contact support or try again later."
|
||||
) from e
|
||||
f"LLM model '{llm_model.value}' is disabled and no fallback model "
|
||||
f"from the same provider is available. Please enable the model or "
|
||||
f"select a different model in the block configuration."
|
||||
)
|
||||
else:
|
||||
# Model is enabled or not in registry (legacy/static model)
|
||||
try:
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
except ValueError:
|
||||
# Model not in cache - try refreshing the registry once if we have DB access
|
||||
logger.warning(f"Model {llm_model.value} not found in registry cache")
|
||||
|
||||
# Try refreshing the registry if we have database access
|
||||
from backend.data.db import is_connected
|
||||
|
||||
if is_connected():
|
||||
try:
|
||||
logger.info(
|
||||
f"Refreshing LLM registry and retrying lookup for {llm_model.value}"
|
||||
)
|
||||
await llm_registry.refresh_llm_registry()
|
||||
# Try again after refresh
|
||||
try:
|
||||
provider = llm_model.metadata.provider
|
||||
context_window = llm_model.context_window
|
||||
model_max_output = llm_model.max_output_tokens or int(2**15)
|
||||
logger.info(
|
||||
f"Successfully loaded model {llm_model.value} metadata after registry refresh"
|
||||
)
|
||||
except ValueError:
|
||||
# Still not found after refresh
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry after refresh. "
|
||||
"Please ensure the model is added and enabled in the LLM registry via the admin UI."
|
||||
)
|
||||
except Exception as refresh_exc:
|
||||
logger.error(f"Failed to refresh LLM registry: {refresh_exc}")
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
|
||||
"Please ensure the model is added to the LLM registry via the admin UI."
|
||||
) from refresh_exc
|
||||
else:
|
||||
# No DB access (e.g., in executor without direct DB connection)
|
||||
# The registry should have been loaded on startup
|
||||
raise ValueError(
|
||||
f"LLM model '{llm_model.value}' not found in registry cache. "
|
||||
"The registry may need to be refreshed. Please contact support or try again later."
|
||||
)
|
||||
|
||||
if compress_prompt_to_fit:
|
||||
prompt = compress_prompt(
|
||||
messages=prompt,
|
||||
target_tokens=llm_model.context_window // 2,
|
||||
target_tokens=context_window // 2,
|
||||
lossy_ok=True,
|
||||
)
|
||||
|
||||
@@ -447,7 +442,7 @@ async def llm_call(
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = await oai_client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
model=model_to_use,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
@@ -494,7 +489,7 @@ async def llm_call(
|
||||
)
|
||||
try:
|
||||
resp = await client.messages.create(
|
||||
model=llm_model.value,
|
||||
model=model_to_use,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
@@ -558,7 +553,7 @@ async def llm_call(
|
||||
client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = {"type": "json_object"} if force_json_output else None
|
||||
response = await client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
model=model_to_use,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
@@ -580,7 +575,7 @@ async def llm_call(
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
response = await client.generate(
|
||||
model=llm_model.value,
|
||||
model=model_to_use,
|
||||
prompt=f"{sys_messages}\n\n{usr_messages}",
|
||||
stream=False,
|
||||
options={"num_ctx": max_tokens},
|
||||
@@ -610,7 +605,7 @@ async def llm_call(
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=llm_model.value,
|
||||
model=model_to_use,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
@@ -652,7 +647,7 @@ async def llm_call(
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=llm_model.value,
|
||||
model=model_to_use,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
@@ -690,7 +685,7 @@ async def llm_call(
|
||||
)
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
model=model_to_use,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
@@ -722,7 +717,7 @@ async def llm_call(
|
||||
)
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
model=model_to_use,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
|
||||
@@ -212,5 +212,72 @@ def get_dynamic_model_slugs() -> set[str]:
|
||||
return set(_dynamic_models.keys())
|
||||
|
||||
|
||||
def get_all_model_slugs_for_validation() -> set[str]:
|
||||
"""
|
||||
Get ALL model slugs (both enabled and disabled) for validation purposes.
|
||||
|
||||
This is used for JSON schema enum validation - we need to accept any known
|
||||
model value (even disabled ones) so that existing graphs don't fail validation.
|
||||
The actual fallback/enforcement happens at runtime in llm_call().
|
||||
"""
|
||||
all_slugs = set(_dynamic_models.keys())
|
||||
all_slugs.update(_static_metadata.keys())
|
||||
return all_slugs
|
||||
|
||||
|
||||
def iter_dynamic_models() -> Iterable[RegistryModel]:
|
||||
return tuple(_dynamic_models.values())
|
||||
|
||||
|
||||
def get_fallback_model_for_disabled(disabled_model_slug: str) -> RegistryModel | None:
|
||||
"""
|
||||
Find a fallback model when the requested model is disabled.
|
||||
|
||||
Looks for an enabled model from the same provider. Prefers models with
|
||||
similar names or capabilities if possible.
|
||||
|
||||
Args:
|
||||
disabled_model_slug: The slug of the disabled model
|
||||
|
||||
Returns:
|
||||
An enabled RegistryModel from the same provider, or None if no fallback found
|
||||
"""
|
||||
disabled_model = _dynamic_models.get(disabled_model_slug)
|
||||
if not disabled_model:
|
||||
return None
|
||||
|
||||
provider = disabled_model.metadata.provider
|
||||
|
||||
# Find all enabled models from the same provider
|
||||
candidates = [
|
||||
model
|
||||
for model in _dynamic_models.values()
|
||||
if model.is_enabled and model.metadata.provider == provider
|
||||
]
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Sort by: prefer models with similar context window, then by name
|
||||
candidates.sort(
|
||||
key=lambda m: (
|
||||
abs(m.metadata.context_window - disabled_model.metadata.context_window),
|
||||
m.display_name.lower(),
|
||||
)
|
||||
)
|
||||
|
||||
return candidates[0]
|
||||
|
||||
|
||||
def is_model_enabled(model_slug: str) -> bool:
|
||||
"""Check if a model is enabled in the registry."""
|
||||
model = _dynamic_models.get(model_slug)
|
||||
if not model:
|
||||
# Model not in registry - assume it's a static/legacy model and allow it
|
||||
return True
|
||||
return model.is_enabled
|
||||
|
||||
|
||||
def get_model_info(model_slug: str) -> RegistryModel | None:
|
||||
"""Get model info from the registry."""
|
||||
return _dynamic_models.get(model_slug)
|
||||
|
||||
@@ -32,25 +32,31 @@ def is_llm_model_field(field_name: str, field_info: Any) -> bool:
|
||||
|
||||
def refresh_llm_model_options(field_schema: dict[str, Any]) -> None:
|
||||
"""
|
||||
Refresh LLM model options and enum values from the registry.
|
||||
Refresh LLM model options from the registry.
|
||||
|
||||
Updates both 'options' (for frontend dropdown) and 'enum' (Pydantic validation)
|
||||
to reflect only currently enabled models.
|
||||
Updates 'options' (for frontend dropdown) to show only enabled models,
|
||||
but keeps the 'enum' (for validation) inclusive of ALL known models.
|
||||
|
||||
This is important because:
|
||||
- Options: What users see in the dropdown (enabled models only)
|
||||
- Enum: What values pass validation (all known models, including disabled)
|
||||
|
||||
Existing graphs may have disabled models selected - they should pass validation
|
||||
and the fallback logic in llm_call() will handle using an alternative model.
|
||||
"""
|
||||
fresh_options = llm_registry.get_llm_model_schema_options()
|
||||
if not fresh_options:
|
||||
return
|
||||
|
||||
enabled_slugs = {opt.get("value") for opt in fresh_options if isinstance(opt, dict)}
|
||||
|
||||
# Update options array
|
||||
# Update options array (UI dropdown) - only enabled models
|
||||
if "options" in field_schema:
|
||||
field_schema["options"] = fresh_options
|
||||
|
||||
# Filter enum to only enabled models
|
||||
if "enum" in field_schema and isinstance(field_schema.get("enum"), list):
|
||||
old_enum = field_schema["enum"]
|
||||
field_schema["enum"] = [val for val in old_enum if val in enabled_slugs]
|
||||
all_known_slugs = llm_registry.get_all_model_slugs_for_validation()
|
||||
if all_known_slugs and "enum" in field_schema:
|
||||
existing_enum = set(field_schema.get("enum", []))
|
||||
combined_enum = existing_enum | all_known_slugs
|
||||
field_schema["enum"] = sorted(combined_enum)
|
||||
|
||||
|
||||
def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None:
|
||||
|
||||
Reference in New Issue
Block a user