fix(registry): address Majdyz review - extract helper, fix schema prefix, return copies, remove re-export

This commit is contained in:
Bentlybro
2026-04-04 20:37:50 +00:00
parent 7e85371ce5
commit 732365cd8f
5 changed files with 107 additions and 103 deletions

View File

@@ -137,10 +137,17 @@ async def lifespan_context(app: fastapi.FastAPI):
try:
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
except Exception as e:
logger.warning(
f"Failed to migrate LLM models at startup: {e}. "
"This is expected in test environments without AgentNode table."
)
err_str = str(e)
if "AgentNode" in err_str or "does not exist" in err_str:
logger.warning(
f"migrate_llm_models skipped: AgentNode table not found ({e}). "
"This is expected in test environments."
)
else:
logger.error(
f"migrate_llm_models failed unexpectedly: {e}",
exc_info=True,
)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()

View File

@@ -38,7 +38,7 @@ from backend.util.request import parse_url
from .block import BlockInput
from .db import BaseDbModel, execute_raw_with_schema
from .db import prisma as db
from .db import query_raw_with_schema, transaction
from .db import execute_raw_with_schema, query_raw_with_schema, transaction
from .dynamic_fields import is_tool_pin, sanitize_pin_name
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, MAX_GRAPH_VERSIONS_FETCH
from .model import CredentialsFieldInfo, CredentialsMetaInput, is_credentials_field_name
@@ -1667,15 +1667,12 @@ async def migrate_llm_models(migrate_to: LlmModel):
# Update each block
for id, path in llm_model_fields.items():
query = (
"""
query = """
UPDATE {schema_prefix}"AgentNode"
SET "constantInput" = jsonb_set("constantInput", $1, to_jsonb($2), true)
WHERE "agentBlockId" = $3
AND "constantInput" ? ($4)::text
AND "constantInput"->>($4)::text NOT IN """
+ enum_values
)
AND "constantInput"->>($4)::text NOT IN """ + enum_values
await execute_raw_with_schema(
query,

View File

@@ -1,6 +1,6 @@
"""LLM Registry - Dynamic model management system."""
from .model import ModelMetadata
from backend.blocks.llm import ModelMetadata
from .registry import (
RegistryModel,
RegistryModelCost,

View File

@@ -1,9 +0,0 @@
"""Type definitions for LLM model metadata.
Re-exports ModelMetadata from blocks.llm to avoid type collision.
In PR #5 (block integration), this will become the canonical location.
"""
from backend.blocks.llm import ModelMetadata
__all__ = ["ModelMetadata"]

View File

@@ -9,7 +9,7 @@ from typing import Any
import prisma.models
from backend.data.llm_registry.model import ModelMetadata
from backend.blocks.llm import ModelMetadata
logger = logging.getLogger(__name__)
@@ -62,6 +62,95 @@ _schema_options: list[dict[str, str]] = []
_lock = asyncio.Lock()
def _record_to_registry_model(record: prisma.models.LlmModel) -> RegistryModel:
"""Transform a raw Prisma LlmModel record into a RegistryModel instance."""
# Parse costs
costs = tuple(
RegistryModelCost(
unit=str(cost.unit), # Convert enum to string
credit_cost=cost.creditCost,
credential_provider=cost.credentialProvider,
credential_id=cost.credentialId,
credential_type=cost.credentialType,
currency=cost.currency,
metadata=dict(cost.metadata or {}),
)
for cost in (record.Costs or [])
)
# Parse creator
creator = None
if record.Creator:
creator = RegistryModelCreator(
id=record.Creator.id,
name=record.Creator.name,
display_name=record.Creator.displayName,
description=record.Creator.description,
website_url=record.Creator.websiteUrl,
logo_url=record.Creator.logoUrl,
)
# Parse capabilities
capabilities = dict(record.capabilities or {})
# Build metadata from record
# Warn if Provider relation is missing (indicates data corruption)
if not record.Provider:
logger.warning(
f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
f"falling back to providerId {record.providerId}"
)
provider_name = (
record.Provider.name if record.Provider else record.providerId
)
provider_display = (
record.Provider.displayName
if record.Provider
else record.providerId
)
# Extract creator name (fallback to "Unknown" if no creator)
creator_name = (
record.Creator.displayName if record.Creator else "Unknown"
)
# Price tier defaults to 1 if not set
if record.priceTier not in (1, 2, 3):
logger.warning(
f"LlmModel {record.slug} has out-of-range priceTier={record.priceTier}, "
"defaulting to 1"
)
price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1
metadata = ModelMetadata(
provider=provider_name,
context_window=record.contextWindow,
max_output_tokens=(
record.maxOutputTokens
if record.maxOutputTokens is not None
else record.contextWindow
),
display_name=record.displayName,
provider_name=provider_display,
creator_name=creator_name,
price_tier=price_tier,
)
return RegistryModel(
slug=record.slug,
display_name=record.displayName,
description=record.description,
metadata=metadata,
capabilities=capabilities,
extra_metadata=dict(record.metadata or {}),
provider_display_name=provider_display,
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
costs=costs,
creator=creator,
)
async def refresh_llm_registry() -> None:
"""
Refresh the LLM registry from the database.
@@ -83,87 +172,7 @@ async def refresh_llm_registry() -> None:
# Build model instances
new_models: dict[str, RegistryModel] = {}
for record in records:
# Parse costs
costs = tuple(
RegistryModelCost(
unit=str(cost.unit), # Convert enum to string
credit_cost=cost.creditCost,
credential_provider=cost.credentialProvider,
credential_id=cost.credentialId,
credential_type=cost.credentialType,
currency=cost.currency,
metadata=dict(cost.metadata or {}),
)
for cost in (record.Costs or [])
)
# Parse creator
creator = None
if record.Creator:
creator = RegistryModelCreator(
id=record.Creator.id,
name=record.Creator.name,
display_name=record.Creator.displayName,
description=record.Creator.description,
website_url=record.Creator.websiteUrl,
logo_url=record.Creator.logoUrl,
)
# Parse capabilities
capabilities = dict(record.capabilities or {})
# Build metadata from record
# Warn if Provider relation is missing (indicates data corruption)
if not record.Provider:
logger.warning(
f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
f"falling back to providerId {record.providerId}"
)
provider_name = (
record.Provider.name if record.Provider else record.providerId
)
provider_display = (
record.Provider.displayName
if record.Provider
else record.providerId
)
# Extract creator name (fallback to "Unknown" if no creator)
creator_name = (
record.Creator.displayName if record.Creator else "Unknown"
)
# Price tier defaults to 1 if not set
price_tier = record.priceTier if record.priceTier in (1, 2, 3) else 1
metadata = ModelMetadata(
provider=provider_name,
context_window=record.contextWindow,
max_output_tokens=(
record.maxOutputTokens
if record.maxOutputTokens is not None
else record.contextWindow
),
display_name=record.displayName,
provider_name=provider_display,
creator_name=creator_name,
price_tier=price_tier,
)
# Create model instance
model = RegistryModel(
slug=record.slug,
display_name=record.displayName,
description=record.description,
metadata=metadata,
capabilities=capabilities,
extra_metadata=dict(record.metadata or {}),
provider_display_name=provider_display,
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
costs=costs,
creator=creator,
)
model = _record_to_registry_model(record)
new_models[record.slug] = model
# Atomic swap
@@ -213,7 +222,7 @@ def get_enabled_models() -> list[RegistryModel]:
def get_schema_options() -> list[dict[str, str]]:
"""Get schema options for model selection dropdown (enabled models only)."""
return _schema_options
return list(_schema_options)
def get_default_model_slug() -> str | None: