Refactor LLM registry integration and schema updates

Moved LLM registry schema update logic to a shared utility (llm_schema_utils.py) and refactored block and credentials schema post-processing to use this helper. Extracted executor registry initialization and notification handling into llm_registry_init.py for better separation of concerns. Updated manager.py to use new initialization and subscription functions, improving maintainability and clarity of LLM registry refresh logic.
Bentlybro
2025-12-01 17:55:43 +00:00
parent 6bbeb22943
commit dfc42003a1
5 changed files with 202 additions and 299 deletions
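
In short: two near-identical schema post-processing passes (one in BlockSchema, one in CredentialsMetaInput) collapse into a single shared helper that mutates the schema in place. A toy sketch of the resulting shape, separate from the diff below (the stub registry and field names are invented for illustration; only update_schema_with_llm_registry is the real entry point):

    # Toy illustration of the consolidation: both schema producers now funnel
    # through one shared, idempotent post-processing helper.
    from typing import Any

    def fake_registry_options() -> list[dict[str, str]]:
        # Stand-in for llm_registry.get_llm_model_schema_options() (illustrative)
        return [{"label": "GPT-4o", "value": "gpt-4o"}]

    def update_schema_with_llm_registry(schema: dict[str, Any]) -> None:
        # Shared post-processing: mutate the schema in place
        for field_schema in schema.get("properties", {}).values():
            if isinstance(field_schema, dict) and "options" in field_schema:
                field_schema["options"] = fake_registry_options()

    # Call site 1: block schema generation (BlockSchema.jsonschema)
    block_schema = {"properties": {"model": {"options": []}}}
    update_schema_with_llm_registry(block_schema)

    # Call site 2: credentials schema post-processing (CredentialsMetaInput)
    creds_schema = {"properties": {"model": {"options": []}}}
    update_schema_with_llm_registry(creds_schema)

    assert block_schema == creds_schema  # both paths get identical treatment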

View File

@@ -38,6 +38,7 @@ from backend.util.exceptions import (
 )
 from backend.util.settings import Config
+from .llm_schema_utils import update_schema_with_llm_registry
 from .model import (
     ContributorDetails,
     Credentials,
@@ -187,187 +188,11 @@ class BlockSchema(BaseModel):
         cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
 
-        # Always post-process to ensure discriminator is present for multi-provider credentials fields
-        # and refresh LLM model options to get latest enabled/disabled status
-        # This handles cases where the schema was generated before the registry was loaded or updated
-        # Note: We mutate the cached schema directly, which is safe since post-processing is idempotent
-        # IMPORTANT: We always refresh options even if schema was cached, to ensure disabled models are excluded
-        cls._ensure_discriminator_in_schema(cls.cached_jsonschema, cls)
+        # Always post-process to ensure LLM registry data is up-to-date
+        # This refreshes model options and discriminator mappings even if schema was cached
+        update_schema_with_llm_registry(cls.cached_jsonschema, cls)
 
         return cls.cached_jsonschema
 
-    @staticmethod
-    def _ensure_discriminator_in_schema(schema: dict[str, Any], model_class: type | None = None) -> None:
-        """Ensure discriminator is present in multi-provider credentials fields and refresh LLM model options."""
-        properties = schema.get("properties", {})
-        for field_name, field_schema in properties.items():
-            if not isinstance(field_schema, dict):
-                continue
-            # Check if this is an LLM model field by checking the field definition
-            is_llm_model_field = False
-            if model_class and hasattr(model_class, "model_fields") and field_name in model_class.model_fields:
-                try:
-                    field_info = model_class.model_fields[field_name]
-                    # Check if json_schema_extra has "options" (set by llm_model_schema_extra)
-                    if hasattr(field_info, "json_schema_extra") and isinstance(field_info.json_schema_extra, dict):
-                        if "options" in field_info.json_schema_extra:
-                            is_llm_model_field = True
-                    # Also check if the field type is LlmModel
-                    if not is_llm_model_field and hasattr(field_info, "annotation"):
-                        from backend.blocks.llm import LlmModel
-                        from typing import get_origin, get_args
-                        annotation = field_info.annotation
-                        if annotation == LlmModel:
-                            is_llm_model_field = True
-                        else:
-                            # Check for Optional[LlmModel] or Union types
-                            origin = get_origin(annotation)
-                            if origin:
-                                args = get_args(annotation)
-                                if LlmModel in args:
-                                    is_llm_model_field = True
-                except Exception:
-                    pass
-            # Only refresh LLM model options for LLM model fields
-            # This prevents filtering other enum fields that aren't LLM models
-            if is_llm_model_field:
-                def refresh_options_in_schema(schema_part: dict[str, Any], path: str = "") -> bool:
-                    """Recursively refresh options in schema part. Returns True if options were found and refreshed."""
-                    import logging
-                    logger = logging.getLogger(__name__)
-                    # Check for "options" key (used by frontend for select dropdowns)
-                    has_options = "options" in schema_part and isinstance(schema_part.get("options"), list)
-                    # Check for "enum" key (Pydantic generates this for Enum fields)
-                    has_enum = "enum" in schema_part and isinstance(schema_part.get("enum"), list)
-                    if has_options or has_enum:
-                        try:
-                            from backend.data import llm_registry
-                            # Always refresh options from registry to get latest enabled/disabled status
-                            fresh_options = llm_registry.get_llm_model_schema_options()
-                            if fresh_options:
-                                # Get enabled model slugs from fresh options
-                                enabled_slugs = {opt.get("value") for opt in fresh_options if isinstance(opt, dict) and "value" in opt}
-                                # Update "options" if present
-                                if has_options:
-                                    old_count = len(schema_part["options"])
-                                    old_slugs = {opt.get("value") for opt in schema_part["options"] if isinstance(opt, dict) and "value" in opt}
-                                    schema_part["options"] = fresh_options
-                                    new_count = len(fresh_options)
-                                    new_slugs = enabled_slugs
-                                    # Log if there's a difference (models added/removed)
-                                    if old_count != new_count or old_slugs != new_slugs:
-                                        removed = old_slugs - new_slugs
-                                        added = new_slugs - old_slugs
-                                        if removed or added:
-                                            logger.info(
-                                                "Refreshed LLM model options for field %s%s: %d -> %d models. "
-                                                "Removed: %s, Added: %s",
-                                                field_name, f".{path}" if path else "", old_count, new_count, removed, added
-                                            )
-                                # Update "enum" if present - filter to only enabled models
-                                if has_enum:
-                                    old_enum = schema_part.get("enum", [])
-                                    # Filter enum values to only include enabled models
-                                    filtered_enum = [val for val in old_enum if val in enabled_slugs]
-                                    schema_part["enum"] = filtered_enum
-                                    if len(old_enum) != len(filtered_enum):
-                                        removed_enum = set(old_enum) - enabled_slugs
-                                        logger.info(
-                                            "Filtered LLM model enum for field %s%s: %d -> %d models. "
-                                            "Removed disabled: %s",
-                                            field_name, f".{path}" if path else "", len(old_enum), len(filtered_enum), removed_enum
-                                        )
-                                return True
-                        except Exception as e:
-                            logger.warning("Failed to refresh LLM model options for field %s%s: %s", field_name, f".{path}" if path else "", e)
-                    # Check nested structures
-                    for key in ["anyOf", "oneOf", "allOf"]:
-                        if key in schema_part and isinstance(schema_part[key], list):
-                            for idx, item in enumerate(schema_part[key]):
-                                if isinstance(item, dict) and refresh_options_in_schema(item, f"{path}.{key}[{idx}]" if path else f"{key}[{idx}]"):
-                                    return True
-                    return False
-                refresh_options_in_schema(field_schema)
-            # Check if this is a credentials field - look for credentials_provider or credentials_types
-            has_credentials_provider = "credentials_provider" in field_schema
-            has_credentials_types = "credentials_types" in field_schema
-            if not (has_credentials_provider or has_credentials_types):
-                continue
-            # This is a credentials field
-            providers = field_schema.get("credentials_provider", [])
-            # If providers not in field schema yet, try to get from model class
-            if not providers and model_class and hasattr(model_class, "model_fields"):
-                try:
-                    if field_name in model_class.model_fields:
-                        field_info = model_class.model_fields[field_name]
-                        # Check if this is a CredentialsMetaInput field
-                        from backend.data.model import CredentialsMetaInput
-                        if (hasattr(field_info, "annotation") and
-                                inspect.isclass(field_info.annotation) and
-                                issubclass(get_origin(field_info.annotation) or field_info.annotation, CredentialsMetaInput)):
-                            # Get providers from the annotation
-                            providers_list = CredentialsMetaInput.allowed_providers.__func__(field_info.annotation)
-                            if providers_list:
-                                providers = list(providers_list)
-                                field_schema["credentials_provider"] = providers
-                except Exception:
-                    pass
-            # Check if this is a multi-provider field
-            if isinstance(providers, list) and len(providers) > 1:
-                # Multi-provider field - ensure discriminator is set
-                if "discriminator" not in field_schema:
-                    # Try to get discriminator from model field definition
-                    discriminator_found = False
-                    if model_class and hasattr(model_class, "model_fields") and field_name in model_class.model_fields:
-                        try:
-                            field_info = model_class.model_fields[field_name]
-                            if hasattr(field_info, "json_schema_extra") and isinstance(field_info.json_schema_extra, dict):
-                                discriminator = field_info.json_schema_extra.get("discriminator")
-                                if discriminator:
-                                    field_schema["discriminator"] = discriminator
-                                    discriminator_found = True
-                        except Exception:
-                            pass
-                    # If not found, check if this looks like an LLM field and default to "model"
-                    if not discriminator_found:
-                        llm_providers = {"openai", "anthropic", "groq", "open_router", "llama_api", "aiml_api", "v0", "ollama"}
-                        if any(p in llm_providers for p in providers):
-                            field_schema["discriminator"] = "model"
-                # If discriminator is "model", ensure discriminator_mapping is populated
-                if field_schema.get("discriminator") == "model":
-                    mapping = field_schema.get("discriminator_mapping")
-                    # If mapping is empty, missing, or None, refresh from registry
-                    if not mapping or (isinstance(mapping, dict) and len(mapping) == 0):
-                        try:
-                            from backend.data import llm_registry
-                            refreshed_mapping = llm_registry.get_llm_discriminator_mapping()
-                            if refreshed_mapping:
-                                field_schema["discriminator_mapping"] = refreshed_mapping
-                            else:
-                                # Ensure at least an empty dict is present
-                                field_schema["discriminator_mapping"] = {}
-                        except Exception:
-                            if "discriminator_mapping" not in field_schema:
-                                field_schema["discriminator_mapping"] = {}
 
     @classmethod
     def validate_data(cls, data: BlockInput) -> str | None:
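
The replaced comments note that mutating the cached schema is safe because the post-processing is idempotent. A minimal sketch of why that holds, with an invented enabled-model set standing in for the registry (not part of the diff):

    import copy
    from typing import Any

    ENABLED = {"gpt-4o", "claude-sonnet-4"}  # illustrative enabled-model set

    def refresh(schema: dict[str, Any]) -> None:
        # Filter the enum down to enabled models, as the real helper does
        field = schema["properties"]["model"]
        field["enum"] = [v for v in field["enum"] if v in ENABLED]

    schema = {"properties": {"model": {"enum": ["gpt-4o", "claude-sonnet-4", "legacy-model"]}}}
    refresh(schema)
    first_pass = copy.deepcopy(schema)
    refresh(schema)  # running again over the cached schema changes nothing
    assert schema == first_pass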

View File

@@ -0,0 +1,111 @@
"""
Helper utilities for LLM registry integration with block schemas.
This module handles the dynamic injection of discriminator mappings
and model options from the LLM registry into block schemas.
"""
import logging
from typing import Any
from backend.data import llm_registry
logger = logging.getLogger(__name__)
def is_llm_model_field(field_name: str, field_info: Any) -> bool:
"""
Check if a field is an LLM model selection field.
Returns True if the field has 'options' in json_schema_extra
(set by llm_model_schema_extra() in blocks/llm.py).
"""
if not hasattr(field_info, "json_schema_extra"):
return False
extra = field_info.json_schema_extra
if isinstance(extra, dict):
return "options" in extra
return False
def refresh_llm_model_options(field_schema: dict[str, Any]) -> None:
"""
Refresh LLM model options and enum values from the registry.
Updates both 'options' (for frontend dropdown) and 'enum' (Pydantic validation)
to reflect only currently enabled models.
"""
fresh_options = llm_registry.get_llm_model_schema_options()
if not fresh_options:
return
enabled_slugs = {opt.get("value") for opt in fresh_options if isinstance(opt, dict)}
# Update options array
if "options" in field_schema:
field_schema["options"] = fresh_options
# Filter enum to only enabled models
if "enum" in field_schema and isinstance(field_schema.get("enum"), list):
old_enum = field_schema["enum"]
field_schema["enum"] = [val for val in old_enum if val in enabled_slugs]
def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None:
"""
Refresh discriminator_mapping for fields that use model-based discrimination.
The discriminator is already set when AICredentialsField() creates the field.
We only need to refresh the mapping when models are added/removed.
"""
if field_schema.get("discriminator") != "model":
return
# Always refresh the mapping to get latest models
fresh_mapping = llm_registry.get_llm_discriminator_mapping()
if fresh_mapping:
field_schema["discriminator_mapping"] = fresh_mapping
def update_schema_with_llm_registry(
schema: dict[str, Any],
model_class: type | None = None
) -> None:
"""
Update a JSON schema with current LLM registry data.
Refreshes:
1. Model options for LLM model selection fields (dropdown choices)
2. Discriminator mappings for credentials fields (model → provider)
Args:
schema: The JSON schema to update (mutated in-place)
model_class: The Pydantic model class (optional, for field introspection)
"""
properties = schema.get("properties", {})
for field_name, field_schema in properties.items():
if not isinstance(field_schema, dict):
continue
# Refresh model options for LLM model fields
if model_class and hasattr(model_class, "model_fields"):
field_info = model_class.model_fields.get(field_name)
if field_info and is_llm_model_field(field_name, field_info):
try:
refresh_llm_model_options(field_schema)
except Exception as exc:
logger.warning(
"Failed to refresh LLM options for field %s: %s",
field_name, exc
)
# Refresh discriminator mapping for fields that use model discrimination
try:
refresh_llm_discriminator_mapping(field_schema)
except Exception as exc:
logger.warning(
"Failed to refresh discriminator mapping for field %s: %s",
field_name, exc
)
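
For orientation, here is roughly the effect this helper has on a schema fragment, assuming the registry reports gpt-4o as the only enabled model and a model_class is supplied for introspection; the concrete values are invented for illustration (not part of the diff):

    # Illustrative before/after of update_schema_with_llm_registry.
    # Only the keys the helper touches are shown.
    before = {
        "properties": {
            "model": {
                "options": [],                        # stale: registry not loaded yet
                "enum": ["gpt-4o", "disabled-model"],
            },
            "credentials": {
                "credentials_provider": ["openai", "anthropic"],
                "discriminator": "model",
                "discriminator_mapping": {},          # empty at field-definition time
            },
        }
    }

    after = {
        "properties": {
            "model": {
                "options": [{"label": "GPT-4o", "value": "gpt-4o"}],  # refreshed
                "enum": ["gpt-4o"],                   # disabled models filtered out
            },
            "credentials": {
                "credentials_provider": ["openai", "anthropic"],
                "discriminator": "model",
                "discriminator_mapping": {"gpt-4o": "openai"},  # model -> provider
            },
        }
    }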

View File

@@ -44,6 +44,8 @@ from backend.integrations.providers import ProviderName
 from backend.util.json import loads as json_loads
 from backend.util.settings import Secrets
+from .llm_schema_utils import update_schema_with_llm_registry
+
 # Type alias for any provider name (including custom ones)
 AnyProviderName = str  # Will be validated as ProviderName at runtime
@@ -532,61 +534,8 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
schema["credentials_provider"] = allowed_providers
schema["credentials_types"] = model_class.allowed_cred_types()
# For LLM credentials fields, ensure discriminator and discriminator_mapping are populated
# This handles the case where the mapping was empty at field definition time
# Check all properties for credentials fields
properties = schema.get("properties", {})
for field_name, field_schema in properties.items():
if not isinstance(field_schema, dict):
continue
# Check if this is a credentials field (has credentials_provider)
if "credentials_provider" in field_schema:
# Check if this field should have a discriminator (multiple providers)
providers = field_schema.get("credentials_provider", [])
if isinstance(providers, list) and len(providers) > 1:
# This is a multi-provider credentials field - ensure discriminator is set
if "discriminator" not in field_schema:
# Try to get discriminator from the field definition
# Check if this is an LLM credentials field by looking at the model fields
discriminator_found = False
try:
if hasattr(model_class, "model_fields") and field_name in model_class.model_fields:
field_info = model_class.model_fields[field_name]
# Check json_schema_extra on the FieldInfo
if hasattr(field_info, "json_schema_extra"):
if isinstance(field_info.json_schema_extra, dict):
discriminator = field_info.json_schema_extra.get("discriminator")
if discriminator:
field_schema["discriminator"] = discriminator
discriminator_found = True
except Exception:
pass
# If still not found and this looks like an LLM field (has multiple AI providers),
# default to "model" discriminator
if not discriminator_found:
# Check if providers include common LLM providers
llm_providers = {"openai", "anthropic", "groq", "open_router", "llama_api", "aiml_api", "v0"}
if any(p in llm_providers for p in providers):
field_schema["discriminator"] = "model"
# If discriminator is "model", ensure discriminator_mapping is populated
if field_schema.get("discriminator") == "model":
mapping = field_schema.get("discriminator_mapping")
# If mapping is empty, missing, or None, refresh from registry
if not mapping or (isinstance(mapping, dict) and len(mapping) == 0):
try:
from backend.data import llm_registry
refreshed_mapping = llm_registry.get_llm_discriminator_mapping()
if refreshed_mapping:
field_schema["discriminator_mapping"] = refreshed_mapping
except Exception:
# If registry isn't available, ensure at least an empty dict is present
if "discriminator_mapping" not in field_schema:
field_schema["discriminator_mapping"] = {}
# Do not return anything, just mutate schema in place
# Ensure LLM discriminators are populated (delegates to shared helper)
update_schema_with_llm_registry(schema, model_class)
model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
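
For context on why the mapping matters: a consumer such as a credentials picker can resolve the right provider from the selected model via this discriminator metadata. A hedged sketch with invented mapping contents (provider_for_model is not a function from the codebase):

    from typing import Any

    def provider_for_model(field_schema: dict[str, Any], chosen_model: str) -> str | None:
        # Resolve provider from the model slug via the discriminator metadata
        if field_schema.get("discriminator") != "model":
            return None
        return field_schema.get("discriminator_mapping", {}).get(chosen_model)

    creds_field = {
        "credentials_provider": ["openai", "anthropic"],
        "discriminator": "model",
        "discriminator_mapping": {"gpt-4o": "openai", "claude-sonnet-4": "anthropic"},
    }
    assert provider_for_model(creds_field, "claude-sonnet-4") == "anthropic"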

View File

@@ -0,0 +1,70 @@
"""
Helper functions for LLM registry initialization in executor context.
These functions handle refreshing the LLM registry when the executor starts
and subscribing to real-time updates via Redis pub/sub.
"""
import logging
from backend.data import db, llm_registry
from backend.data.block import BlockSchema, initialize_blocks
from backend.data.block_cost_config import refresh_llm_costs
from backend.data.llm_registry_notifications import subscribe_to_registry_refresh
logger = logging.getLogger(__name__)
async def initialize_registry_for_executor() -> None:
"""
Initialize blocks and refresh LLM registry in the executor context.
This must run in the executor's event loop to have access to the database.
"""
try:
# Connect to database if not already connected
if not db.is_connected():
await db.connect()
logger.info("[GraphExecutor] Connected to database for registry refresh")
# Refresh LLM registry before initializing blocks
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
logger.info("[GraphExecutor] LLM registry refreshed")
# Initialize blocks (also refreshes registry, but we do it explicitly above)
await initialize_blocks()
logger.info("[GraphExecutor] Blocks initialized")
except Exception as exc:
logger.warning(
"[GraphExecutor] Failed to refresh LLM registry on startup: %s",
exc,
exc_info=True,
)
async def refresh_registry_on_notification() -> None:
"""Refresh LLM registry when notified via Redis pub/sub."""
try:
# Ensure DB is connected
if not db.is_connected():
await db.connect()
# Refresh registry and costs
await llm_registry.refresh_llm_registry()
refresh_llm_costs()
# Clear block schema caches so they regenerate with new model options
BlockSchema.clear_all_schema_caches()
logger.info("[GraphExecutor] LLM registry refreshed from notification")
except Exception as exc:
logger.error(
"[GraphExecutor] Failed to refresh LLM registry from notification: %s",
exc,
exc_info=True,
)
async def subscribe_to_registry_updates() -> None:
"""Subscribe to Redis pub/sub for LLM registry refresh notifications."""
await subscribe_to_registry_refresh(refresh_registry_on_notification)

View File

@@ -622,70 +622,18 @@ class ExecutionProcessor:
         self.node_execution_thread.start()
         self.node_evaluation_thread.start()
 
-        # Initialize blocks and refresh LLM registry in the execution loop
-        async def init_registry():
-            try:
-                from backend.data import db, llm_registry
-                from backend.data.block_cost_config import refresh_llm_costs
-                from backend.data.block import initialize_blocks
-                # Connect to database for registry refresh
-                if not db.is_connected():
-                    await db.connect()
-                    logger.info("[GraphExecutor] Connected to database for registry refresh")
-                # Refresh LLM registry before initializing blocks
-                await llm_registry.refresh_llm_registry()
-                refresh_llm_costs()
-                logger.info("[GraphExecutor] LLM registry refreshed")
-                # Initialize blocks (this also refreshes registry, but we do it explicitly above)
-                await initialize_blocks()
-                logger.info("[GraphExecutor] Blocks initialized")
-            except Exception as exc:
-                logger.warning(
-                    "[GraphExecutor] Failed to refresh LLM registry on startup: %s",
-                    exc,
-                    exc_info=True,
-                )
-
-        # Refresh registry function for notifications
-        async def refresh_registry_from_notification():
-            """Refresh LLM registry when notified via Redis pub/sub"""
-            try:
-                from backend.data import db, llm_registry
-                from backend.data.block_cost_config import refresh_llm_costs
-                from backend.data.block import BlockSchema
-                # Ensure DB is connected
-                if not db.is_connected():
-                    await db.connect()
-                # Refresh registry and costs
-                await llm_registry.refresh_llm_registry()
-                refresh_llm_costs()
-                # Clear block schema caches so they regenerate with new model options
-                BlockSchema.clear_all_schema_caches()
-                logger.info("[GraphExecutor] LLM registry refreshed from notification")
-            except Exception as exc:
-                logger.error(
-                    "[GraphExecutor] Failed to refresh LLM registry from notification: %s",
-                    exc,
-                    exc_info=True,
-                )
-
-        # Schedule registry refresh in the execution loop
-        asyncio.run_coroutine_threadsafe(init_registry(), self.node_execution_loop)
-
-        # Subscribe to registry refresh notifications
-        async def subscribe_to_refresh():
-            from backend.data.llm_registry_notifications import subscribe_to_registry_refresh
-            await subscribe_to_registry_refresh(refresh_registry_from_notification)
-
-        # Start subscription in a background task
-        asyncio.run_coroutine_threadsafe(subscribe_to_refresh(), self.node_execution_loop)
+        # Initialize LLM registry and subscribe to updates
+        from backend.executor.llm_registry_init import (
+            initialize_registry_for_executor,
+            subscribe_to_registry_updates,
+        )
+
+        asyncio.run_coroutine_threadsafe(
+            initialize_registry_for_executor(), self.node_execution_loop
+        )
+        asyncio.run_coroutine_threadsafe(
+            subscribe_to_registry_updates(), self.node_execution_loop
+        )
 
         logger.info(f"[GraphExecutor] {self.tid} started")
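
One optional hardening note, not part of this commit: asyncio.run_coroutine_threadsafe returns a concurrent.futures.Future, and an exception set on it is dropped silently if never retrieved. Both helper coroutines already catch their own exceptions, but a done callback would surface anything that escapes. A sketch (log_if_failed is an invented name):

    import concurrent.futures
    import logging

    logger = logging.getLogger(__name__)

    def log_if_failed(fut: concurrent.futures.Future) -> None:
        # Done-callback: runs once the scheduled coroutine has finished
        if fut.cancelled():
            return
        exc = fut.exception()
        if exc is not None:
            logger.error("[GraphExecutor] registry task failed", exc_info=exc)

    # Usage sketch (loop is the executor's event loop, e.g. self.node_execution_loop):
    # future = asyncio.run_coroutine_threadsafe(initialize_registry_for_executor(), loop)
    # future.add_done_callback(log_if_failed)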