Add fallback logic for disabled LLM models

Introduces fallback selection for disabled LLM models in llm_call, preferring enabled models from the same provider. Updates registry utilities to support fallback lookup, model info retrieval, and validation of all known model slugs. Schema utilities now keep all known models in validation enums while showing only enabled models in UI options.
Bentlybro
2025-12-08 11:29:31 +00:00
parent a97fdba554
commit 7435739053
3 changed files with 164 additions and 96 deletions
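The policy described above can be captured in a short, self-contained sketch. This is an approximation for illustration only: the Model dataclass and resolve_model helper below are hypothetical stand-ins, not code from this commit; the real implementation lives in llm_call and the registry utilities shown in the diffs that follow.

```python
from dataclasses import dataclass

# Hypothetical, simplified stand-in for a registry model record.
@dataclass
class Model:
    slug: str
    provider: str
    context_window: int
    is_enabled: bool


def resolve_model(requested: Model, registry: list[Model]) -> Model:
    """Return the requested model, or an enabled same-provider fallback if it is disabled."""
    if requested.is_enabled:
        return requested
    candidates = [
        m for m in registry
        if m.is_enabled and m.provider == requested.provider
    ]
    if not candidates:
        raise ValueError(f"'{requested.slug}' is disabled and no fallback is available")
    # Prefer the enabled model whose context window is closest to the requested one.
    return min(candidates, key=lambda m: abs(m.context_window - requested.context_window))
```

Validation stays permissive: all known slugs remain valid enum values, so graphs that still reference a disabled model load normally and are redirected to a fallback at call time.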

View File

@@ -337,93 +337,88 @@ async def llm_call(
- prompt_tokens: The number of tokens used in the prompt.
- completion_tokens: The number of tokens used in the completion.
"""
# Get model metadata - try cache first, then fallback to async lookup
# Also check if the model is enabled
try:
provider = llm_model.metadata.provider
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
# Get model metadata and check if enabled - with fallback support
# The model we'll actually use (may differ if original is disabled)
model_to_use = llm_model.value
# Check if model is enabled - get from registry
from backend.data.llm_registry import _dynamic_models
# Check if model is in registry and if it's enabled
from backend.data.llm_registry import (
get_fallback_model_for_disabled,
get_model_info,
)
if llm_model.value in _dynamic_models:
model_info = _dynamic_models[llm_model.value]
if not model_info.is_enabled:
raise ValueError(f"LLM model '{llm_model.value}' is disabled.")
except ValueError as e:
# Re-raise if it's our disabled model error
if "is disabled" in str(e):
raise
# Model not in cache - try refreshing the registry once if we have DB access
import logging
model_info = get_model_info(llm_model.value)
logger = logging.getLogger(__name__)
logger.warning(
"Model %s not found in registry cache",
llm_model.value,
)
# Try refreshing the registry if we have database access
from backend.data.db import is_connected
if is_connected():
try:
logger.info(
"Refreshing LLM registry and retrying lookup for %s",
llm_model.value,
)
await llm_registry.refresh_llm_registry()
# Try again after refresh
try:
provider = llm_model.metadata.provider
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
# Check if model is enabled after refresh
from backend.data.llm_registry import _dynamic_models
if llm_model.value in _dynamic_models:
model_info = _dynamic_models[llm_model.value]
if not model_info.is_enabled:
raise ValueError(
f"LLM model '{llm_model.value}' is disabled. "
"Please enable it in the LLM registry via the admin UI to use this model."
)
logger.info(
"Successfully loaded model %s metadata after registry refresh",
llm_model.value,
)
except ValueError as ve:
# Re-raise if it's our disabled model error
if "is disabled" in str(ve):
raise
# Still not found after refresh
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry after refresh. "
"Please ensure the model is added and enabled in the LLM registry via the admin UI."
)
except Exception as refresh_exc:
logger.error(
"Failed to refresh LLM registry: %s", refresh_exc, exc_info=True
)
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
"Please ensure the model is added to the LLM registry via the admin UI."
) from refresh_exc
if model_info and not model_info.is_enabled:
# Model is disabled - try to find a fallback from the same provider
fallback = get_fallback_model_for_disabled(llm_model.value)
if fallback:
logger.warning(
f"Model '{llm_model.value}' is disabled. Using fallback model '{fallback.slug}' from the same provider ({fallback.metadata.provider})."
)
model_to_use = fallback.slug
# Use fallback model's metadata
provider = fallback.metadata.provider
context_window = fallback.metadata.context_window
model_max_output = fallback.metadata.max_output_tokens or int(2**15)
else:
# No DB access (e.g., in executor without direct DB connection)
# The registry should have been loaded on startup
# No fallback available - raise error
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry cache. "
"The registry may need to be refreshed. Please contact support or try again later."
) from e
f"LLM model '{llm_model.value}' is disabled and no fallback model "
f"from the same provider is available. Please enable the model or "
f"select a different model in the block configuration."
)
else:
# Model is enabled or not in registry (legacy/static model)
try:
provider = llm_model.metadata.provider
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
except ValueError:
# Model not in cache - try refreshing the registry once if we have DB access
logger.warning(f"Model {llm_model.value} not found in registry cache")
# Try refreshing the registry if we have database access
from backend.data.db import is_connected
if is_connected():
try:
logger.info(
f"Refreshing LLM registry and retrying lookup for {llm_model.value}"
)
await llm_registry.refresh_llm_registry()
# Try again after refresh
try:
provider = llm_model.metadata.provider
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
logger.info(
f"Successfully loaded model {llm_model.value} metadata after registry refresh"
)
except ValueError:
# Still not found after refresh
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry after refresh. "
"Please ensure the model is added and enabled in the LLM registry via the admin UI."
)
except Exception as refresh_exc:
logger.error(f"Failed to refresh LLM registry: {refresh_exc}")
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry and failed to refresh. "
"Please ensure the model is added to the LLM registry via the admin UI."
) from refresh_exc
else:
# No DB access (e.g., in executor without direct DB connection)
# The registry should have been loaded on startup
raise ValueError(
f"LLM model '{llm_model.value}' not found in registry cache. "
"The registry may need to be refreshed. Please contact support or try again later."
)
if compress_prompt_to_fit:
prompt = compress_prompt(
messages=prompt,
target_tokens=llm_model.context_window // 2,
target_tokens=context_window // 2,
lossy_ok=True,
)
@@ -447,7 +442,7 @@ async def llm_call(
response_format = {"type": "json_object"}
response = await oai_client.chat.completions.create(
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
@@ -494,7 +489,7 @@ async def llm_call(
)
try:
resp = await client.messages.create(
model=llm_model.value,
model=model_to_use,
system=sysprompt,
messages=messages,
max_tokens=max_tokens,
@@ -558,7 +553,7 @@ async def llm_call(
client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
response_format = {"type": "json_object"} if force_json_output else None
response = await client.chat.completions.create(
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
@@ -580,7 +575,7 @@ async def llm_call(
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
response = await client.generate(
model=llm_model.value,
model=model_to_use,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
options={"num_ctx": max_tokens},
@@ -610,7 +605,7 @@ async def llm_call(
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
@@ -652,7 +647,7 @@ async def llm_call(
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore
@@ -690,7 +685,7 @@ async def llm_call(
)
completion = client.chat.completions.create(
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
max_tokens=max_tokens,
)
@@ -722,7 +717,7 @@ async def llm_call(
)
response = await client.chat.completions.create(
model=llm_model.value,
model=model_to_use,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
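Condensed from the hunks above, the new resolution step at the top of llm_call behaves roughly as follows. This is a simplified sketch, not a verbatim excerpt: the registry-refresh retry and most error handling are elided.

```python
# Simplified sketch of the new flow in llm_call (names taken from the diff above).
model_to_use = llm_model.value
model_info = get_model_info(llm_model.value)

if model_info and not model_info.is_enabled:
    fallback = get_fallback_model_for_disabled(llm_model.value)
    if fallback is None:
        raise ValueError(f"LLM model '{llm_model.value}' is disabled and has no fallback")
    # Swap in the fallback and adopt its limits for prompt compression and output caps.
    model_to_use = fallback.slug
    provider = fallback.metadata.provider
    context_window = fallback.metadata.context_window
    model_max_output = fallback.metadata.max_output_tokens or int(2**15)
else:
    # Enabled, or a static/legacy model not tracked by the registry.
    provider = llm_model.metadata.provider
    context_window = llm_model.context_window
    model_max_output = llm_model.max_output_tokens or int(2**15)

# Every provider branch below then passes model_to_use instead of llm_model.value.
```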

View File

@@ -212,5 +212,72 @@ def get_dynamic_model_slugs() -> set[str]:
return set(_dynamic_models.keys())
def get_all_model_slugs_for_validation() -> set[str]:
"""
Get ALL model slugs (both enabled and disabled) for validation purposes.
This is used for JSON schema enum validation - we need to accept any known
model value (even disabled ones) so that existing graphs don't fail validation.
The actual fallback/enforcement happens at runtime in llm_call().
"""
all_slugs = set(_dynamic_models.keys())
all_slugs.update(_static_metadata.keys())
return all_slugs
def iter_dynamic_models() -> Iterable[RegistryModel]:
return tuple(_dynamic_models.values())
def get_fallback_model_for_disabled(disabled_model_slug: str) -> RegistryModel | None:
"""
Find a fallback model when the requested model is disabled.
Looks for an enabled model from the same provider, preferring the candidate whose
context window is closest to the disabled model's (ties broken by display name).
Args:
disabled_model_slug: The slug of the disabled model
Returns:
An enabled RegistryModel from the same provider, or None if no fallback found
"""
disabled_model = _dynamic_models.get(disabled_model_slug)
if not disabled_model:
return None
provider = disabled_model.metadata.provider
# Find all enabled models from the same provider
candidates = [
model
for model in _dynamic_models.values()
if model.is_enabled and model.metadata.provider == provider
]
if not candidates:
return None
# Sort by: prefer models with similar context window, then by name
candidates.sort(
key=lambda m: (
abs(m.metadata.context_window - disabled_model.metadata.context_window),
m.display_name.lower(),
)
)
return candidates[0]
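For example, if a provider has enabled models with context windows of 32k, 128k, and 200k tokens and a disabled 100k-token model is requested, the sort key above picks the 128k model. A hedged usage sketch follows; the slug is invented and whether a fallback exists depends on what the registry actually contains.

```python
# Hypothetical slug; registry contents depend on the deployment.
fallback = get_fallback_model_for_disabled("example-provider/large-100k")
if fallback is not None:
    print(f"Falling back to {fallback.slug} ({fallback.metadata.context_window} tokens)")
else:
    print("No enabled model from the same provider is available")
```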
def is_model_enabled(model_slug: str) -> bool:
"""Check if a model is enabled in the registry."""
model = _dynamic_models.get(model_slug)
if not model:
# Model not in registry - assume it's a static/legacy model and allow it
return True
return model.is_enabled
def get_model_info(model_slug: str) -> RegistryModel | None:
"""Get model info from the registry."""
return _dynamic_models.get(model_slug)
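Note the permissive default in is_model_enabled: a slug that is not in the dynamic registry is treated as enabled, so static/legacy models keep working. A minimal illustration, assuming the slug below is not registered:

```python
# Hypothetical slug that is absent from the dynamic registry.
assert is_model_enabled("legacy/static-model") is True  # unknown slug, allowed by default
assert get_model_info("legacy/static-model") is None    # no registry record to inspect
```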

View File

@@ -32,25 +32,31 @@ def is_llm_model_field(field_name: str, field_info: Any) -> bool:
def refresh_llm_model_options(field_schema: dict[str, Any]) -> None:
"""
Refresh LLM model options and enum values from the registry.
Refresh LLM model options from the registry.
Updates both 'options' (for frontend dropdown) and 'enum' (Pydantic validation)
to reflect only currently enabled models.
Updates 'options' (for frontend dropdown) to show only enabled models,
but keeps the 'enum' (for validation) inclusive of ALL known models.
This is important because:
- Options: What users see in the dropdown (enabled models only)
- Enum: What values pass validation (all known models, including disabled)
Existing graphs may still have a disabled model selected; they should pass validation,
and the fallback logic in llm_call() will substitute an alternative model at runtime.
"""
fresh_options = llm_registry.get_llm_model_schema_options()
if not fresh_options:
return
enabled_slugs = {opt.get("value") for opt in fresh_options if isinstance(opt, dict)}
# Update options array
# Update options array (UI dropdown) - only enabled models
if "options" in field_schema:
field_schema["options"] = fresh_options
# Filter enum to only enabled models
if "enum" in field_schema and isinstance(field_schema.get("enum"), list):
old_enum = field_schema["enum"]
field_schema["enum"] = [val for val in old_enum if val in enabled_slugs]
all_known_slugs = llm_registry.get_all_model_slugs_for_validation()
if all_known_slugs and "enum" in field_schema:
existing_enum = set(field_schema.get("enum", []))
combined_enum = existing_enum | all_known_slugs
field_schema["enum"] = sorted(combined_enum)
def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None: