mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(registry): address Majdyz review - extract helper, fix schema prefix, return copies, remove re-export
This commit is contained in:
@@ -137,10 +137,17 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
try:
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to migrate LLM models at startup: {e}. "
|
||||
"This is expected in test environments without AgentNode table."
|
||||
)
|
||||
err_str = str(e)
|
||||
if "AgentNode" in err_str or "does not exist" in err_str:
|
||||
logger.warning(
|
||||
f"migrate_llm_models skipped: AgentNode table not found ({e}). "
|
||||
"This is expected in test environments."
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"migrate_llm_models failed unexpectedly: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ from backend.util.request import parse_url
|
||||
from .block import BlockInput
|
||||
from .db import BaseDbModel, execute_raw_with_schema
|
||||
from .db import prisma as db
|
||||
from .db import query_raw_with_schema, transaction
|
||||
from .db import execute_raw_with_schema, query_raw_with_schema, transaction
|
||||
from .dynamic_fields import is_tool_pin, sanitize_pin_name
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, MAX_GRAPH_VERSIONS_FETCH
|
||||
from .model import CredentialsFieldInfo, CredentialsMetaInput, is_credentials_field_name
|
||||
@@ -1667,15 +1667,12 @@ async def migrate_llm_models(migrate_to: LlmModel):
|
||||
|
||||
# Update each block
|
||||
for id, path in llm_model_fields.items():
|
||||
query = (
|
||||
"""
|
||||
query = """
|
||||
UPDATE {schema_prefix}"AgentNode"
|
||||
SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true)
|
||||
WHERE "agentBlockId" = $3
|
||||
AND "constantInput" ? ($4)::text
|
||||
AND "constantInput"->>($4)::text NOT IN """
|
||||
+ enum_values
|
||||
)
|
||||
AND "constantInput"->>($4)::text NOT IN """ + enum_values
|
||||
|
||||
await execute_raw_with_schema(
|
||||
query,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""LLM Registry - Dynamic model management system."""
|
||||
|
||||
from .model import ModelMetadata
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
from .registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCost,
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
"""Type definitions for LLM model metadata.
|
||||
|
||||
Re-exports ModelMetadata from blocks.llm to avoid type collision.
|
||||
In PR #5 (block integration), this will become the canonical location.
|
||||
"""
|
||||
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
|
||||
__all__ = ["ModelMetadata"]
|
||||
@@ -9,7 +9,7 @@ from typing import Any
|
||||
|
||||
import prisma.models
|
||||
|
||||
from backend.data.llm_registry.model import ModelMetadata
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -62,6 +62,95 @@ _schema_options: list[dict[str, str]] = []
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:
|
||||
"""Transform a raw Prisma LlmModel record into a RegistryModel instance."""
|
||||
# Parse costs
|
||||
costs = tuple(
|
||||
RegistryModelCost(
|
||||
unit=str(cost.unit), # Convert enum to string
|
||||
credit_cost=cost.creditCost,
|
||||
credential_provider=cost.credentialProvider,
|
||||
credential_id=cost.credentialId,
|
||||
credential_type=cost.credentialType,
|
||||
currency=cost.currency,
|
||||
metadata=dict(cost.metadata or {}),
|
||||
)
|
||||
for cost in (record.Costs or [])
|
||||
)
|
||||
|
||||
# Parse creator
|
||||
creator = None
|
||||
if record.Creator:
|
||||
creator = RegistryModelCreator(
|
||||
id=record.Creator.id,
|
||||
name=record.Creator.name,
|
||||
display_name=record.Creator.displayName,
|
||||
description=record.Creator.description,
|
||||
website_url=record.Creator.websiteUrl,
|
||||
logo_url=record.Creator.logoUrl,
|
||||
)
|
||||
|
||||
# Parse capabilities
|
||||
capabilities = dict(record.capabilities or {})
|
||||
|
||||
# Build metadata from record
|
||||
# Warn if Provider relation is missing (indicates data corruption)
|
||||
if not record.Provider:
|
||||
logger.warning(
|
||||
f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
|
||||
f"falling back to providerId {record.providerId}"
|
||||
)
|
||||
provider_name = (
|
||||
record.Provider.name if record.Provider else record.providerId
|
||||
)
|
||||
provider_display = (
|
||||
record.Provider.displayName
|
||||
if record.Provider
|
||||
else record.providerId
|
||||
)
|
||||
|
||||
# Extract creator name (fallback to "Unknown" if no creator)
|
||||
creator_name = (
|
||||
record.Creator.displayName if record.Creator else "Unknown"
|
||||
)
|
||||
|
||||
# Price tier defaults to 1 if not set
|
||||
if record.priceTier not in (1, 2, 3):
|
||||
logger.warning(
|
||||
f"LlmModel {record.slug} has out-of-range priceTier={record.priceTier}, "
|
||||
"defaulting to 1"
|
||||
)
|
||||
price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1
|
||||
|
||||
metadata = ModelMetadata(
|
||||
provider=provider_name,
|
||||
context_window=record.contextWindow,
|
||||
max_output_tokens=(
|
||||
record.maxOutputTokens
|
||||
if record.maxOutputTokens is not None
|
||||
else record.contextWindow
|
||||
),
|
||||
display_name=record.displayName,
|
||||
provider_name=provider_display,
|
||||
creator_name=creator_name,
|
||||
price_tier=price_tier,
|
||||
)
|
||||
|
||||
return RegistryModel(
|
||||
slug=record.slug,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
metadata=metadata,
|
||||
capabilities=capabilities,
|
||||
extra_metadata=dict(record.metadata or {}),
|
||||
provider_display_name=provider_display,
|
||||
is_enabled=record.isEnabled,
|
||||
is_recommended=record.isRecommended,
|
||||
costs=costs,
|
||||
creator=creator,
|
||||
)
|
||||
|
||||
|
||||
async def refresh_llm_registry() -> None:
|
||||
"""
|
||||
Refresh the LLM registry from the database.
|
||||
@@ -83,87 +172,7 @@ async def refresh_llm_registry() -> None:
|
||||
# Build model instances
|
||||
new_models: dict[str, RegistryModel] = {}
|
||||
for record in records:
|
||||
# Parse costs
|
||||
costs = tuple(
|
||||
RegistryModelCost(
|
||||
unit=str(cost.unit), # Convert enum to string
|
||||
credit_cost=cost.creditCost,
|
||||
credential_provider=cost.credentialProvider,
|
||||
credential_id=cost.credentialId,
|
||||
credential_type=cost.credentialType,
|
||||
currency=cost.currency,
|
||||
metadata=dict(cost.metadata or {}),
|
||||
)
|
||||
for cost in (record.Costs or [])
|
||||
)
|
||||
|
||||
# Parse creator
|
||||
creator = None
|
||||
if record.Creator:
|
||||
creator = RegistryModelCreator(
|
||||
id=record.Creator.id,
|
||||
name=record.Creator.name,
|
||||
display_name=record.Creator.displayName,
|
||||
description=record.Creator.description,
|
||||
website_url=record.Creator.websiteUrl,
|
||||
logo_url=record.Creator.logoUrl,
|
||||
)
|
||||
|
||||
# Parse capabilities
|
||||
capabilities = dict(record.capabilities or {})
|
||||
|
||||
# Build metadata from record
|
||||
# Warn if Provider relation is missing (indicates data corruption)
|
||||
if not record.Provider:
|
||||
logger.warning(
|
||||
f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
|
||||
f"falling back to providerId {record.providerId}"
|
||||
)
|
||||
provider_name = (
|
||||
record.Provider.name if record.Provider else record.providerId
|
||||
)
|
||||
provider_display = (
|
||||
record.Provider.displayName
|
||||
if record.Provider
|
||||
else record.providerId
|
||||
)
|
||||
|
||||
# Extract creator name (fallback to "Unknown" if no creator)
|
||||
creator_name = (
|
||||
record.Creator.displayName if record.Creator else "Unknown"
|
||||
)
|
||||
|
||||
# Price tier defaults to 1 if not set
|
||||
price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1
|
||||
|
||||
metadata = ModelMetadata(
|
||||
provider=provider_name,
|
||||
context_window=record.contextWindow,
|
||||
max_output_tokens=(
|
||||
record.maxOutputTokens
|
||||
if record.maxOutputTokens is not None
|
||||
else record.contextWindow
|
||||
),
|
||||
display_name=record.displayName,
|
||||
provider_name=provider_display,
|
||||
creator_name=creator_name,
|
||||
price_tier=price_tier,
|
||||
)
|
||||
|
||||
# Create model instance
|
||||
model = RegistryModel(
|
||||
slug=record.slug,
|
||||
display_name=record.displayName,
|
||||
description=record.description,
|
||||
metadata=metadata,
|
||||
capabilities=capabilities,
|
||||
extra_metadata=dict(record.metadata or {}),
|
||||
provider_display_name=provider_display,
|
||||
is_enabled=record.isEnabled,
|
||||
is_recommended=record.isRecommended,
|
||||
costs=costs,
|
||||
creator=creator,
|
||||
)
|
||||
model = _record_to_registry_model(record)
|
||||
new_models[record.slug] = model
|
||||
|
||||
# Atomic swap
|
||||
@@ -213,7 +222,7 @@ def get_enabled_models() -> list[RegistryModel]:
|
||||
|
||||
def get_schema_options() -> list[dict[str, str]]:
|
||||
"""Get schema options for model selection dropdown (enabled models only)."""
|
||||
return _schema_options
|
||||
return list(_schema_options)
|
||||
|
||||
|
||||
def get_default_model_slug() -> str | None:
|
||||
|
||||
Reference in New Issue
Block a user