From dfc42003a1c99ccfee090be415bbc7de68ad5c5d Mon Sep 17 00:00:00 2001 From: Bentlybro Date: Mon, 1 Dec 2025 17:55:43 +0000 Subject: [PATCH] Refactor LLM registry integration and schema updates Moved LLM registry schema update logic to a shared utility (llm_schema_utils.py) and refactored block and credentials schema post-processing to use this helper. Extracted executor registry initialization and notification handling into llm_registry_init.py for better separation of concerns. Updated manager.py to use new initialization and subscription functions, improving maintainability and clarity of LLM registry refresh logic. --- .../backend/backend/data/block.py | 185 +----------------- .../backend/backend/data/llm_schema_utils.py | 111 +++++++++++ .../backend/backend/data/model.py | 59 +----- .../backend/executor/llm_registry_init.py | 70 +++++++ .../backend/backend/executor/manager.py | 76 ++----- 5 files changed, 202 insertions(+), 299 deletions(-) create mode 100644 autogpt_platform/backend/backend/data/llm_schema_utils.py create mode 100644 autogpt_platform/backend/backend/executor/llm_registry_init.py diff --git a/autogpt_platform/backend/backend/data/block.py b/autogpt_platform/backend/backend/data/block.py index 0bc5deb21e..001e0203da 100644 --- a/autogpt_platform/backend/backend/data/block.py +++ b/autogpt_platform/backend/backend/data/block.py @@ -38,6 +38,7 @@ from backend.util.exceptions import ( ) from backend.util.settings import Config +from .llm_schema_utils import update_schema_with_llm_registry from .model import ( ContributorDetails, Credentials, @@ -187,187 +188,11 @@ class BlockSchema(BaseModel): cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model)) - # Always post-process to ensure discriminator is present for multi-provider credentials fields - # and refresh LLM model options to get latest enabled/disabled status - # This handles cases where the schema was generated before the registry was loaded or updated - # Note: We mutate the cached schema directly, which is safe since post-processing is idempotent - # IMPORTANT: We always refresh options even if schema was cached, to ensure disabled models are excluded - cls._ensure_discriminator_in_schema(cls.cached_jsonschema, cls) - + # Always post-process to ensure LLM registry data is up-to-date + # This refreshes model options and discriminator mappings even if schema was cached + update_schema_with_llm_registry(cls.cached_jsonschema, cls) + return cls.cached_jsonschema - - @staticmethod - def _ensure_discriminator_in_schema(schema: dict[str, Any], model_class: type | None = None) -> None: - """Ensure discriminator is present in multi-provider credentials fields and refresh LLM model options.""" - properties = schema.get("properties", {}) - for field_name, field_schema in properties.items(): - if not isinstance(field_schema, dict): - continue - - # Check if this is an LLM model field by checking the field definition - is_llm_model_field = False - if model_class and hasattr(model_class, "model_fields") and field_name in model_class.model_fields: - try: - field_info = model_class.model_fields[field_name] - # Check if json_schema_extra has "options" (set by llm_model_schema_extra) - if hasattr(field_info, "json_schema_extra") and isinstance(field_info.json_schema_extra, dict): - if "options" in field_info.json_schema_extra: - is_llm_model_field = True - # Also check if the field type is LlmModel - if not is_llm_model_field and hasattr(field_info, "annotation"): - from backend.blocks.llm import LlmModel - from typing 
import get_origin, get_args - annotation = field_info.annotation - if annotation == LlmModel: - is_llm_model_field = True - else: - # Check for Optional[LlmModel] or Union types - origin = get_origin(annotation) - if origin: - args = get_args(annotation) - if LlmModel in args: - is_llm_model_field = True - except Exception: - pass - - # Only refresh LLM model options for LLM model fields - # This prevents filtering other enum fields that aren't LLM models - if is_llm_model_field: - def refresh_options_in_schema(schema_part: dict[str, Any], path: str = "") -> bool: - """Recursively refresh options in schema part. Returns True if options were found and refreshed.""" - import logging - logger = logging.getLogger(__name__) - - # Check for "options" key (used by frontend for select dropdowns) - has_options = "options" in schema_part and isinstance(schema_part.get("options"), list) - # Check for "enum" key (Pydantic generates this for Enum fields) - has_enum = "enum" in schema_part and isinstance(schema_part.get("enum"), list) - - if has_options or has_enum: - try: - from backend.data import llm_registry - # Always refresh options from registry to get latest enabled/disabled status - fresh_options = llm_registry.get_llm_model_schema_options() - if fresh_options: - # Get enabled model slugs from fresh options - enabled_slugs = {opt.get("value") for opt in fresh_options if isinstance(opt, dict) and "value" in opt} - - # Update "options" if present - if has_options: - old_count = len(schema_part["options"]) - old_slugs = {opt.get("value") for opt in schema_part["options"] if isinstance(opt, dict) and "value" in opt} - schema_part["options"] = fresh_options - new_count = len(fresh_options) - new_slugs = enabled_slugs - - # Log if there's a difference (models added/removed) - if old_count != new_count or old_slugs != new_slugs: - removed = old_slugs - new_slugs - added = new_slugs - old_slugs - if removed or added: - logger.info( - "Refreshed LLM model options for field %s%s: %d -> %d models. " - "Removed: %s, Added: %s", - field_name, f".{path}" if path else "", old_count, new_count, removed, added - ) - - # Update "enum" if present - filter to only enabled models - if has_enum: - old_enum = schema_part.get("enum", []) - # Filter enum values to only include enabled models - filtered_enum = [val for val in old_enum if val in enabled_slugs] - schema_part["enum"] = filtered_enum - - if len(old_enum) != len(filtered_enum): - removed_enum = set(old_enum) - enabled_slugs - logger.info( - "Filtered LLM model enum for field %s%s: %d -> %d models. 
" - "Removed disabled: %s", - field_name, f".{path}" if path else "", len(old_enum), len(filtered_enum), removed_enum - ) - - return True - except Exception as e: - logger.warning("Failed to refresh LLM model options for field %s%s: %s", field_name, f".{path}" if path else "", e) - - # Check nested structures - for key in ["anyOf", "oneOf", "allOf"]: - if key in schema_part and isinstance(schema_part[key], list): - for idx, item in enumerate(schema_part[key]): - if isinstance(item, dict) and refresh_options_in_schema(item, f"{path}.{key}[{idx}]" if path else f"{key}[{idx}]"): - return True - - return False - - refresh_options_in_schema(field_schema) - - # Check if this is a credentials field - look for credentials_provider or credentials_types - has_credentials_provider = "credentials_provider" in field_schema - has_credentials_types = "credentials_types" in field_schema - - if not (has_credentials_provider or has_credentials_types): - continue - - # This is a credentials field - providers = field_schema.get("credentials_provider", []) - - # If providers not in field schema yet, try to get from model class - if not providers and model_class and hasattr(model_class, "model_fields"): - try: - if field_name in model_class.model_fields: - field_info = model_class.model_fields[field_name] - # Check if this is a CredentialsMetaInput field - from backend.data.model import CredentialsMetaInput - if (hasattr(field_info, "annotation") and - inspect.isclass(field_info.annotation) and - issubclass(get_origin(field_info.annotation) or field_info.annotation, CredentialsMetaInput)): - # Get providers from the annotation - providers_list = CredentialsMetaInput.allowed_providers.__func__(field_info.annotation) - if providers_list: - providers = list(providers_list) - field_schema["credentials_provider"] = providers - except Exception: - pass - - # Check if this is a multi-provider field - if isinstance(providers, list) and len(providers) > 1: - # Multi-provider field - ensure discriminator is set - if "discriminator" not in field_schema: - # Try to get discriminator from model field definition - discriminator_found = False - if model_class and hasattr(model_class, "model_fields") and field_name in model_class.model_fields: - try: - field_info = model_class.model_fields[field_name] - if hasattr(field_info, "json_schema_extra") and isinstance(field_info.json_schema_extra, dict): - discriminator = field_info.json_schema_extra.get("discriminator") - if discriminator: - field_schema["discriminator"] = discriminator - discriminator_found = True - except Exception: - pass - - # If not found, check if this looks like an LLM field and default to "model" - if not discriminator_found: - llm_providers = {"openai", "anthropic", "groq", "open_router", "llama_api", "aiml_api", "v0", "ollama"} - if any(p in llm_providers for p in providers): - field_schema["discriminator"] = "model" - - # If discriminator is "model", ensure discriminator_mapping is populated - if field_schema.get("discriminator") == "model": - mapping = field_schema.get("discriminator_mapping") - # If mapping is empty, missing, or None, refresh from registry - if not mapping or (isinstance(mapping, dict) and len(mapping) == 0): - try: - from backend.data import llm_registry - refreshed_mapping = llm_registry.get_llm_discriminator_mapping() - if refreshed_mapping: - field_schema["discriminator_mapping"] = refreshed_mapping - else: - # Ensure at least an empty dict is present - field_schema["discriminator_mapping"] = {} - except Exception: - if 
"discriminator_mapping" not in field_schema: - field_schema["discriminator_mapping"] = {} @classmethod def validate_data(cls, data: BlockInput) -> str | None: diff --git a/autogpt_platform/backend/backend/data/llm_schema_utils.py b/autogpt_platform/backend/backend/data/llm_schema_utils.py new file mode 100644 index 0000000000..d0495749c3 --- /dev/null +++ b/autogpt_platform/backend/backend/data/llm_schema_utils.py @@ -0,0 +1,111 @@ +""" +Helper utilities for LLM registry integration with block schemas. + +This module handles the dynamic injection of discriminator mappings +and model options from the LLM registry into block schemas. +""" +import logging +from typing import Any + +from backend.data import llm_registry + +logger = logging.getLogger(__name__) + + +def is_llm_model_field(field_name: str, field_info: Any) -> bool: + """ + Check if a field is an LLM model selection field. + + Returns True if the field has 'options' in json_schema_extra + (set by llm_model_schema_extra() in blocks/llm.py). + """ + if not hasattr(field_info, "json_schema_extra"): + return False + + extra = field_info.json_schema_extra + if isinstance(extra, dict): + return "options" in extra + + return False + + +def refresh_llm_model_options(field_schema: dict[str, Any]) -> None: + """ + Refresh LLM model options and enum values from the registry. + + Updates both 'options' (for frontend dropdown) and 'enum' (Pydantic validation) + to reflect only currently enabled models. + """ + fresh_options = llm_registry.get_llm_model_schema_options() + if not fresh_options: + return + + enabled_slugs = {opt.get("value") for opt in fresh_options if isinstance(opt, dict)} + + # Update options array + if "options" in field_schema: + field_schema["options"] = fresh_options + + # Filter enum to only enabled models + if "enum" in field_schema and isinstance(field_schema.get("enum"), list): + old_enum = field_schema["enum"] + field_schema["enum"] = [val for val in old_enum if val in enabled_slugs] + + +def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None: + """ + Refresh discriminator_mapping for fields that use model-based discrimination. + + The discriminator is already set when AICredentialsField() creates the field. + We only need to refresh the mapping when models are added/removed. + """ + if field_schema.get("discriminator") != "model": + return + + # Always refresh the mapping to get latest models + fresh_mapping = llm_registry.get_llm_discriminator_mapping() + if fresh_mapping: + field_schema["discriminator_mapping"] = fresh_mapping + + +def update_schema_with_llm_registry( + schema: dict[str, Any], + model_class: type | None = None +) -> None: + """ + Update a JSON schema with current LLM registry data. + + Refreshes: + 1. Model options for LLM model selection fields (dropdown choices) + 2. 
Discriminator mappings for credentials fields (model → provider) + + Args: + schema: The JSON schema to update (mutated in-place) + model_class: The Pydantic model class (optional, for field introspection) + """ + properties = schema.get("properties", {}) + + for field_name, field_schema in properties.items(): + if not isinstance(field_schema, dict): + continue + + # Refresh model options for LLM model fields + if model_class and hasattr(model_class, "model_fields"): + field_info = model_class.model_fields.get(field_name) + if field_info and is_llm_model_field(field_name, field_info): + try: + refresh_llm_model_options(field_schema) + except Exception as exc: + logger.warning( + "Failed to refresh LLM options for field %s: %s", + field_name, exc + ) + + # Refresh discriminator mapping for fields that use model discrimination + try: + refresh_llm_discriminator_mapping(field_schema) + except Exception as exc: + logger.warning( + "Failed to refresh discriminator mapping for field %s: %s", + field_name, exc + ) diff --git a/autogpt_platform/backend/backend/data/model.py b/autogpt_platform/backend/backend/data/model.py index 8e89c74055..617b788ae4 100644 --- a/autogpt_platform/backend/backend/data/model.py +++ b/autogpt_platform/backend/backend/data/model.py @@ -44,6 +44,8 @@ from backend.integrations.providers import ProviderName from backend.util.json import loads as json_loads from backend.util.settings import Secrets +from .llm_schema_utils import update_schema_with_llm_registry + # Type alias for any provider name (including custom ones) AnyProviderName = str # Will be validated as ProviderName at runtime @@ -532,61 +534,8 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]): schema["credentials_provider"] = allowed_providers schema["credentials_types"] = model_class.allowed_cred_types() - # For LLM credentials fields, ensure discriminator and discriminator_mapping are populated - # This handles the case where the mapping was empty at field definition time - # Check all properties for credentials fields - properties = schema.get("properties", {}) - for field_name, field_schema in properties.items(): - if not isinstance(field_schema, dict): - continue - - # Check if this is a credentials field (has credentials_provider) - if "credentials_provider" in field_schema: - # Check if this field should have a discriminator (multiple providers) - providers = field_schema.get("credentials_provider", []) - if isinstance(providers, list) and len(providers) > 1: - # This is a multi-provider credentials field - ensure discriminator is set - if "discriminator" not in field_schema: - # Try to get discriminator from the field definition - # Check if this is an LLM credentials field by looking at the model fields - discriminator_found = False - try: - if hasattr(model_class, "model_fields") and field_name in model_class.model_fields: - field_info = model_class.model_fields[field_name] - # Check json_schema_extra on the FieldInfo - if hasattr(field_info, "json_schema_extra"): - if isinstance(field_info.json_schema_extra, dict): - discriminator = field_info.json_schema_extra.get("discriminator") - if discriminator: - field_schema["discriminator"] = discriminator - discriminator_found = True - except Exception: - pass - - # If still not found and this looks like an LLM field (has multiple AI providers), - # default to "model" discriminator - if not discriminator_found: - # Check if providers include common LLM providers - llm_providers = {"openai", "anthropic", "groq", "open_router", "llama_api", 
"aiml_api", "v0"} - if any(p in llm_providers for p in providers): - field_schema["discriminator"] = "model" - - # If discriminator is "model", ensure discriminator_mapping is populated - if field_schema.get("discriminator") == "model": - mapping = field_schema.get("discriminator_mapping") - # If mapping is empty, missing, or None, refresh from registry - if not mapping or (isinstance(mapping, dict) and len(mapping) == 0): - try: - from backend.data import llm_registry - refreshed_mapping = llm_registry.get_llm_discriminator_mapping() - if refreshed_mapping: - field_schema["discriminator_mapping"] = refreshed_mapping - except Exception: - # If registry isn't available, ensure at least an empty dict is present - if "discriminator_mapping" not in field_schema: - field_schema["discriminator_mapping"] = {} - - # Do not return anything, just mutate schema in place + # Ensure LLM discriminators are populated (delegates to shared helper) + update_schema_with_llm_registry(schema, model_class) model_config = ConfigDict( json_schema_extra=_add_json_schema_extra, # type: ignore diff --git a/autogpt_platform/backend/backend/executor/llm_registry_init.py b/autogpt_platform/backend/backend/executor/llm_registry_init.py new file mode 100644 index 0000000000..6143397250 --- /dev/null +++ b/autogpt_platform/backend/backend/executor/llm_registry_init.py @@ -0,0 +1,70 @@ +""" +Helper functions for LLM registry initialization in executor context. + +These functions handle refreshing the LLM registry when the executor starts +and subscribing to real-time updates via Redis pub/sub. +""" +import logging + +from backend.data import db, llm_registry +from backend.data.block import BlockSchema, initialize_blocks +from backend.data.block_cost_config import refresh_llm_costs +from backend.data.llm_registry_notifications import subscribe_to_registry_refresh + +logger = logging.getLogger(__name__) + + +async def initialize_registry_for_executor() -> None: + """ + Initialize blocks and refresh LLM registry in the executor context. + + This must run in the executor's event loop to have access to the database. 
+ """ + try: + # Connect to database if not already connected + if not db.is_connected(): + await db.connect() + logger.info("[GraphExecutor] Connected to database for registry refresh") + + # Refresh LLM registry before initializing blocks + await llm_registry.refresh_llm_registry() + refresh_llm_costs() + logger.info("[GraphExecutor] LLM registry refreshed") + + # Initialize blocks (also refreshes registry, but we do it explicitly above) + await initialize_blocks() + logger.info("[GraphExecutor] Blocks initialized") + except Exception as exc: + logger.warning( + "[GraphExecutor] Failed to refresh LLM registry on startup: %s", + exc, + exc_info=True, + ) + + +async def refresh_registry_on_notification() -> None: + """Refresh LLM registry when notified via Redis pub/sub.""" + try: + # Ensure DB is connected + if not db.is_connected(): + await db.connect() + + # Refresh registry and costs + await llm_registry.refresh_llm_registry() + refresh_llm_costs() + + # Clear block schema caches so they regenerate with new model options + BlockSchema.clear_all_schema_caches() + + logger.info("[GraphExecutor] LLM registry refreshed from notification") + except Exception as exc: + logger.error( + "[GraphExecutor] Failed to refresh LLM registry from notification: %s", + exc, + exc_info=True, + ) + + +async def subscribe_to_registry_updates() -> None: + """Subscribe to Redis pub/sub for LLM registry refresh notifications.""" + await subscribe_to_registry_refresh(refresh_registry_on_notification) diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index 80eb678800..cb24f8c9af 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -622,70 +622,18 @@ class ExecutionProcessor: self.node_execution_thread.start() self.node_evaluation_thread.start() - # Initialize blocks and refresh LLM registry in the execution loop - async def init_registry(): - try: - from backend.data import db, llm_registry - from backend.data.block_cost_config import refresh_llm_costs - from backend.data.block import initialize_blocks - - # Connect to database for registry refresh - if not db.is_connected(): - await db.connect() - logger.info("[GraphExecutor] Connected to database for registry refresh") - - # Refresh LLM registry before initializing blocks - await llm_registry.refresh_llm_registry() - refresh_llm_costs() - logger.info("[GraphExecutor] LLM registry refreshed") - - # Initialize blocks (this also refreshes registry, but we do it explicitly above) - await initialize_blocks() - logger.info("[GraphExecutor] Blocks initialized") - except Exception as exc: - logger.warning( - "[GraphExecutor] Failed to refresh LLM registry on startup: %s", - exc, - exc_info=True, - ) - - # Refresh registry function for notifications - async def refresh_registry_from_notification(): - """Refresh LLM registry when notified via Redis pub/sub""" - try: - from backend.data import db, llm_registry - from backend.data.block_cost_config import refresh_llm_costs - from backend.data.block import BlockSchema - - # Ensure DB is connected - if not db.is_connected(): - await db.connect() - - # Refresh registry and costs - await llm_registry.refresh_llm_registry() - refresh_llm_costs() - - # Clear block schema caches so they regenerate with new model options - BlockSchema.clear_all_schema_caches() - - logger.info("[GraphExecutor] LLM registry refreshed from notification") - except Exception as exc: - logger.error( - 
"[GraphExecutor] Failed to refresh LLM registry from notification: %s", - exc, - exc_info=True, - ) - - # Schedule registry refresh in the execution loop - asyncio.run_coroutine_threadsafe(init_registry(), self.node_execution_loop) - - # Subscribe to registry refresh notifications - async def subscribe_to_refresh(): - from backend.data.llm_registry_notifications import subscribe_to_registry_refresh - await subscribe_to_registry_refresh(refresh_registry_from_notification) - - # Start subscription in a background task - asyncio.run_coroutine_threadsafe(subscribe_to_refresh(), self.node_execution_loop) + # Initialize LLM registry and subscribe to updates + from backend.executor.llm_registry_init import ( + initialize_registry_for_executor, + subscribe_to_registry_updates, + ) + + asyncio.run_coroutine_threadsafe( + initialize_registry_for_executor(), self.node_execution_loop + ) + asyncio.run_coroutine_threadsafe( + subscribe_to_registry_updates(), self.node_execution_loop + ) logger.info(f"[GraphExecutor] {self.tid} started")