refactor: address CodeRabbit/Majdyz review feedback

- Fix ModelMetadata duplicate type collision by importing from blocks.llm
- Remove _json_to_dict helper, use dict() inline
- Add warning when Provider relation is missing (data corruption indicator)
- Optimize get_default_model_slug with next() (single sort pass)
- Optimize _build_schema_options to use list comprehension
- Move llm_registry import to top-level in rest_api.py
- Ensure max_output_tokens falls back to context_window when null

All critical and quick-win issues addressed.
This commit is contained in:
Bentlybro
2026-03-11 13:18:55 +00:00
parent c64246be87
commit 05fa10925c
3 changed files with 45 additions and 62 deletions

View File

@@ -37,6 +37,7 @@ import backend.api.features.workspace.routes as workspace_routes
import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.llm_registry
import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
@@ -121,9 +122,7 @@ async def lifespan_context(app: fastapi.FastAPI):
# Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
# When block integration lands, this should fail hard or skip block initialization
try:
from backend.data.llm_registry import refresh_llm_registry
await refresh_llm_registry()
await backend.data.llm_registry.refresh_llm_registry()
logger.info("LLM registry refreshed successfully at startup")
except Exception as e:
logger.warning(

View File

@@ -1,25 +1,9 @@
"""Type definitions for LLM model metadata."""
"""Type definitions for LLM model metadata.
from typing import Literal, NamedTuple
Re-exports ModelMetadata from blocks.llm to avoid type collision.
In PR #5 (block integration), this will become the canonical location.
"""
from backend.blocks.llm import ModelMetadata
class ModelMetadata(NamedTuple):
"""Metadata for an LLM model.
Attributes:
provider: The provider identifier (e.g., "openai", "anthropic")
context_window: Maximum context window size in tokens
max_output_tokens: Maximum output tokens (None if unlimited)
display_name: Human-readable name for the model
provider_name: Human-readable provider name (e.g., "OpenAI", "Anthropic")
creator_name: Name of the organization that created the model
price_tier: Relative cost tier (1=cheapest, 2=medium, 3=expensive)
"""
provider: str
context_window: int
max_output_tokens: int | None
display_name: str
provider_name: str
creator_name: str
price_tier: Literal[1, 2, 3]
__all__ = ["ModelMetadata"]

View File

@@ -14,16 +14,6 @@ from backend.data.llm_registry.model import ModelMetadata
logger = logging.getLogger(__name__)
def _json_to_dict(value: Any) -> dict[str, Any]:
"""Convert Prisma Json type to dict, with fallback to empty dict."""
if value is None:
return {}
if isinstance(value, dict):
return value
# Prisma Json type should always be a dict at runtime
return dict(value) if value else {}
@dataclass(frozen=True)
class RegistryModelCost:
"""Cost configuration for an LLM model."""
@@ -102,7 +92,7 @@ async def refresh_llm_registry() -> None:
credential_id=cost.credentialId,
credential_type=cost.credentialType,
currency=cost.currency,
metadata=_json_to_dict(cost.metadata),
metadata=dict(cost.metadata or {}),
)
for cost in (record.Costs or [])
)
@@ -120,9 +110,15 @@ async def refresh_llm_registry() -> None:
)
# Parse capabilities
capabilities = _json_to_dict(record.capabilities)
capabilities = dict(record.capabilities or {})
# Build metadata from record
# Warn if Provider relation is missing (indicates data corruption)
if not record.Provider:
logger.warning(
f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
f"falling back to providerId {record.providerId}"
)
provider_name = (
record.Provider.name if record.Provider else record.providerId
)
@@ -143,7 +139,11 @@ async def refresh_llm_registry() -> None:
metadata = ModelMetadata(
provider=provider_name,
context_window=record.contextWindow,
max_output_tokens=record.maxOutputTokens,
max_output_tokens=(
record.maxOutputTokens
if record.maxOutputTokens is not None
else record.contextWindow
),
display_name=record.displayName,
provider_name=provider_display,
creator_name=creator_name,
@@ -157,7 +157,7 @@ async def refresh_llm_registry() -> None:
description=record.description,
metadata=metadata,
capabilities=capabilities,
extra_metadata=_json_to_dict(record.metadata),
extra_metadata=dict(record.metadata or {}),
provider_display_name=provider_display,
is_enabled=record.isEnabled,
is_recommended=record.isRecommended,
@@ -182,19 +182,18 @@ async def refresh_llm_registry() -> None:
def _build_schema_options() -> list[dict[str, str]]:
    """Build schema options for the model-selection dropdown.

    Produces one option per *enabled* model, sorted case-insensitively by
    display name so the dropdown order is deterministic.

    Returns:
        A list of dicts with ``label``, ``value``, ``group``, and
        ``description`` keys.
    """
    # Single pass: sort once, filter to enabled models, and build the
    # option dicts in a comprehension.
    return [
        {
            "label": model.display_name,
            "value": model.slug,
            "group": model.metadata.provider,
            # Description may be None in the registry; the schema expects str.
            "description": model.description or "",
        }
        for model in sorted(
            _dynamic_models.values(), key=lambda m: m.display_name.lower()
        )
        if model.is_enabled
    ]
def get_model(slug: str) -> RegistryModel | None:
@@ -219,17 +218,18 @@ def get_schema_options() -> list[dict[str, str]]:
def get_default_model_slug() -> str | None:
"""Get the default model slug (first recommended, or first enabled)."""
# Prefer recommended models (sorted for deterministic selection)
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name):
if model.is_recommended and model.is_enabled:
return model.slug
# Sort once and use next() to short-circuit on first match
models = sorted(_dynamic_models.values(), key=lambda m: m.display_name)
# Prefer recommended models
recommended = next(
(m.slug for m in models if m.is_recommended and m.is_enabled), None
)
if recommended:
return recommended
# Fallback to first enabled model
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name):
if model.is_enabled:
return model.slug
return None
return next((m.slug for m in models if m.is_enabled), None)
def get_all_model_slugs_for_validation() -> list[str]: