diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index cd8ede47aa..771b6a8fc2 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -137,10 +137,17 @@ async def lifespan_context(app: fastapi.FastAPI): try: await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL) except Exception as e: - logger.warning( - f"Failed to migrate LLM models at startup: {e}. " - "This is expected in test environments without AgentNode table." - ) + err_str = str(e) + if "AgentNode" in err_str or "does not exist" in err_str: + logger.warning( + f"migrate_llm_models skipped: AgentNode table not found ({e}). " + "This is expected in test environments." + ) + else: + logger.error( + f"migrate_llm_models failed unexpectedly: {e}", + exc_info=True, + ) await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs() diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index ab75ed2f42..7eb29eedf5 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -38,7 +38,7 @@ from backend.util.request import parse_url from .block import BlockInput from .db import BaseDbModel, execute_raw_with_schema from .db import prisma as db -from .db import query_raw_with_schema, transaction +from .db import execute_raw_with_schema, query_raw_with_schema, transaction from .dynamic_fields import is_tool_pin, sanitize_pin_name from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, MAX_GRAPH_VERSIONS_FETCH from .model import CredentialsFieldInfo, CredentialsMetaInput, is_credentials_field_name @@ -1667,15 +1667,12 @@ async def migrate_llm_models(migrate_to: LlmModel): # Update each block for id, path in llm_model_fields.items(): - query = ( - """ + query = """ UPDATE {schema_prefix}"AgentNode" SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true) WHERE "agentBlockId" = $3 AND "constantInput" ? ($4)::text - AND "constantInput"->>($4)::text NOT IN """ - + enum_values - ) + AND "constantInput"->>($4)::text NOT IN """ + enum_values await execute_raw_with_schema( query, diff --git a/autogpt_platform/backend/backend/data/llm_registry/__init__.py b/autogpt_platform/backend/backend/data/llm_registry/__init__.py index de18f00fc0..3e6b9896b3 100644 --- a/autogpt_platform/backend/backend/data/llm_registry/__init__.py +++ b/autogpt_platform/backend/backend/data/llm_registry/__init__.py @@ -1,6 +1,6 @@ """LLM Registry - Dynamic model management system.""" -from .model import ModelMetadata +from backend.blocks.llm import ModelMetadata from .registry import ( RegistryModel, RegistryModelCost, diff --git a/autogpt_platform/backend/backend/data/llm_registry/model.py b/autogpt_platform/backend/backend/data/llm_registry/model.py deleted file mode 100644 index 18d4ab7356..0000000000 --- a/autogpt_platform/backend/backend/data/llm_registry/model.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Type definitions for LLM model metadata. - -Re-exports ModelMetadata from blocks.llm to avoid type collision. -In PR #5 (block integration), this will become the canonical location. -""" - -from backend.blocks.llm import ModelMetadata - -__all__ = ["ModelMetadata"] diff --git a/autogpt_platform/backend/backend/data/llm_registry/registry.py b/autogpt_platform/backend/backend/data/llm_registry/registry.py index 1a91b36154..5a2eb000ff 100644 --- a/autogpt_platform/backend/backend/data/llm_registry/registry.py +++ b/autogpt_platform/backend/backend/data/llm_registry/registry.py @@ -9,7 +9,7 @@ from typing import Any import prisma.models -from backend.data.llm_registry.model import ModelMetadata +from backend.blocks.llm import ModelMetadata logger = logging.getLogger(__name__) @@ -62,6 +62,95 @@ _schema_options: list[dict[str, str]] = [] _lock = asyncio.Lock() +def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel: + """Transform a raw Prisma LlmModel record into a RegistryModel instance.""" + # Parse costs + costs = tuple( + RegistryModelCost( + unit=str(cost.unit), # Convert enum to string + credit_cost=cost.creditCost, + credential_provider=cost.credentialProvider, + credential_id=cost.credentialId, + credential_type=cost.credentialType, + currency=cost.currency, + metadata=dict(cost.metadata or {}), + ) + for cost in (record.Costs or []) + ) + + # Parse creator + creator = None + if record.Creator: + creator = RegistryModelCreator( + id=record.Creator.id, + name=record.Creator.name, + display_name=record.Creator.displayName, + description=record.Creator.description, + website_url=record.Creator.websiteUrl, + logo_url=record.Creator.logoUrl, + ) + + # Parse capabilities + capabilities = dict(record.capabilities or {}) + + # Build metadata from record + # Warn if Provider relation is missing (indicates data corruption) + if not record.Provider: + logger.warning( + f"LlmModel {record.slug} has no Provider despite NOT NULL FK - " + f"falling back to providerId {record.providerId}" + ) + provider_name = ( + record.Provider.name if record.Provider else record.providerId + ) + provider_display = ( + record.Provider.displayName + if record.Provider + else record.providerId + ) + + # Extract creator name (fallback to "Unknown" if no creator) + creator_name = ( + record.Creator.displayName if record.Creator else "Unknown" + ) + + # Price tier defaults to 1 if not set + if record.priceTier not in (1, 2, 3): + logger.warning( + f"LlmModel {record.slug} has out-of-range priceTier={record.priceTier}, " + "defaulting to 1" + ) + price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1 + + metadata = ModelMetadata( + provider=provider_name, + context_window=record.contextWindow, + max_output_tokens=( + record.maxOutputTokens + if record.maxOutputTokens is not None + else record.contextWindow + ), + display_name=record.displayName, + provider_name=provider_display, + creator_name=creator_name, + price_tier=price_tier, + ) + + return RegistryModel( + slug=record.slug, + display_name=record.displayName, + description=record.description, + metadata=metadata, + capabilities=capabilities, + extra_metadata=dict(record.metadata or {}), + provider_display_name=provider_display, + is_enabled=record.isEnabled, + is_recommended=record.isRecommended, + costs=costs, + creator=creator, + ) + + async def refresh_llm_registry() -> None: """ Refresh the LLM registry from the database. @@ -83,87 +172,7 @@ async def refresh_llm_registry() -> None: # Build model instances new_models: dict[str, RegistryModel] = {} for record in records: - # Parse costs - costs = tuple( - RegistryModelCost( - unit=str(cost.unit), # Convert enum to string - credit_cost=cost.creditCost, - credential_provider=cost.credentialProvider, - credential_id=cost.credentialId, - credential_type=cost.credentialType, - currency=cost.currency, - metadata=dict(cost.metadata or {}), - ) - for cost in (record.Costs or []) - ) - - # Parse creator - creator = None - if record.Creator: - creator = RegistryModelCreator( - id=record.Creator.id, - name=record.Creator.name, - display_name=record.Creator.displayName, - description=record.Creator.description, - website_url=record.Creator.websiteUrl, - logo_url=record.Creator.logoUrl, - ) - - # Parse capabilities - capabilities = dict(record.capabilities or {}) - - # Build metadata from record - # Warn if Provider relation is missing (indicates data corruption) - if not record.Provider: - logger.warning( - f"LlmModel {record.slug} has no Provider despite NOT NULL FK - " - f"falling back to providerId {record.providerId}" - ) - provider_name = ( - record.Provider.name if record.Provider else record.providerId - ) - provider_display = ( - record.Provider.displayName - if record.Provider - else record.providerId - ) - - # Extract creator name (fallback to "Unknown" if no creator) - creator_name = ( - record.Creator.displayName if record.Creator else "Unknown" - ) - - # Price tier defaults to 1 if not set - price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1 - - metadata = ModelMetadata( - provider=provider_name, - context_window=record.contextWindow, - max_output_tokens=( - record.maxOutputTokens - if record.maxOutputTokens is not None - else record.contextWindow - ), - display_name=record.displayName, - provider_name=provider_display, - creator_name=creator_name, - price_tier=price_tier, - ) - - # Create model instance - model = RegistryModel( - slug=record.slug, - display_name=record.displayName, - description=record.description, - metadata=metadata, - capabilities=capabilities, - extra_metadata=dict(record.metadata or {}), - provider_display_name=provider_display, - is_enabled=record.isEnabled, - is_recommended=record.isRecommended, - costs=costs, - creator=creator, - ) + model = _record_to_registry_model(record) new_models[record.slug] = model # Atomic swap @@ -213,7 +222,7 @@ def get_enabled_models() -> list[RegistryModel]: def get_schema_options() -> list[dict[str, str]]: """Get schema options for model selection dropdown (enabled models only).""" - return _schema_options + return list(_schema_options) def get_default_model_slug() -> str | None: