mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
refactor: address CodeRabbit/Majdyz review feedback
- Fix ModelMetadata duplicate type collision by importing from blocks.llm - Remove _json_to_dict helper, use dict() inline - Add warning when Provider relation is missing (data corruption indicator) - Optimize get_default_model_slug with next() (single sort pass) - Optimize _build_schema_options to use list comprehension - Move llm_registry import to top-level in rest_api.py - Ensure max_output_tokens falls back to context_window when null All critical and quick-win issues addressed.
This commit is contained in:
@@ -37,6 +37,7 @@ import backend.api.features.workspace.routes as workspace_routes
|
||||
import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.llm_registry
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.util.service
|
||||
@@ -121,9 +122,7 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
# Note: Graceful fallback for now since no blocks consume registry yet (comes in PR #5)
|
||||
# When block integration lands, this should fail hard or skip block initialization
|
||||
try:
|
||||
from backend.data.llm_registry import refresh_llm_registry
|
||||
|
||||
await refresh_llm_registry()
|
||||
await backend.data.llm_registry.refresh_llm_registry()
|
||||
logger.info("LLM registry refreshed successfully at startup")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
|
||||
@@ -1,25 +1,9 @@
|
||||
"""Type definitions for LLM model metadata."""
|
||||
"""Type definitions for LLM model metadata.
|
||||
|
||||
from typing import Literal, NamedTuple
|
||||
Re-exports ModelMetadata from blocks.llm to avoid type collision.
|
||||
In PR #5 (block integration), this will become the canonical location.
|
||||
"""
|
||||
|
||||
from backend.blocks.llm import ModelMetadata
|
||||
|
||||
class ModelMetadata(NamedTuple):
|
||||
"""Metadata for an LLM model.
|
||||
|
||||
Attributes:
|
||||
provider: The provider identifier (e.g., "openai", "anthropic")
|
||||
context_window: Maximum context window size in tokens
|
||||
max_output_tokens: Maximum output tokens (None if unlimited)
|
||||
display_name: Human-readable name for the model
|
||||
provider_name: Human-readable provider name (e.g., "OpenAI", "Anthropic")
|
||||
creator_name: Name of the organization that created the model
|
||||
price_tier: Relative cost tier (1=cheapest, 2=medium, 3=expensive)
|
||||
"""
|
||||
|
||||
provider: str
|
||||
context_window: int
|
||||
max_output_tokens: int | None
|
||||
display_name: str
|
||||
provider_name: str
|
||||
creator_name: str
|
||||
price_tier: Literal[1, 2, 3]
|
||||
__all__ = ["ModelMetadata"]
|
||||
|
||||
@@ -14,16 +14,6 @@ from backend.data.llm_registry.model import ModelMetadata
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _json_to_dict(value: Any) -> dict[str, Any]:
|
||||
"""Convert Prisma Json type to dict, with fallback to empty dict."""
|
||||
if value is None:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
# Prisma Json type should always be a dict at runtime
|
||||
return dict(value) if value else {}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistryModelCost:
|
||||
"""Cost configuration for an LLM model."""
|
||||
@@ -102,7 +92,7 @@ async def refresh_llm_registry() -> None:
|
||||
credential_id=cost.credentialId,
|
||||
credential_type=cost.credentialType,
|
||||
currency=cost.currency,
|
||||
metadata=_json_to_dict(cost.metadata),
|
||||
metadata=dict(cost.metadata or {}),
|
||||
)
|
||||
for cost in (record.Costs or [])
|
||||
)
|
||||
@@ -120,9 +110,15 @@ async def refresh_llm_registry() -> None:
|
||||
)
|
||||
|
||||
# Parse capabilities
|
||||
capabilities = _json_to_dict(record.capabilities)
|
||||
capabilities = dict(record.capabilities or {})
|
||||
|
||||
# Build metadata from record
|
||||
# Warn if Provider relation is missing (indicates data corruption)
|
||||
if not record.Provider:
|
||||
logger.warning(
|
||||
f"LlmModel {record.slug} has no Provider despite NOT NULL FK - "
|
||||
f"falling back to providerId {record.providerId}"
|
||||
)
|
||||
provider_name = (
|
||||
record.Provider.name if record.Provider else record.providerId
|
||||
)
|
||||
@@ -143,7 +139,11 @@ async def refresh_llm_registry() -> None:
|
||||
metadata = ModelMetadata(
|
||||
provider=provider_name,
|
||||
context_window=record.contextWindow,
|
||||
max_output_tokens=record.maxOutputTokens,
|
||||
max_output_tokens=(
|
||||
record.maxOutputTokens
|
||||
if record.maxOutputTokens is not None
|
||||
else record.contextWindow
|
||||
),
|
||||
display_name=record.displayName,
|
||||
provider_name=provider_display,
|
||||
creator_name=creator_name,
|
||||
@@ -157,7 +157,7 @@ async def refresh_llm_registry() -> None:
|
||||
description=record.description,
|
||||
metadata=metadata,
|
||||
capabilities=capabilities,
|
||||
extra_metadata=_json_to_dict(record.metadata),
|
||||
extra_metadata=dict(record.metadata or {}),
|
||||
provider_display_name=provider_display,
|
||||
is_enabled=record.isEnabled,
|
||||
is_recommended=record.isRecommended,
|
||||
@@ -182,19 +182,18 @@ async def refresh_llm_registry() -> None:
|
||||
|
||||
def _build_schema_options() -> list[dict[str, str]]:
|
||||
"""Build schema options for model selection dropdown. Only includes enabled models."""
|
||||
options: list[dict[str, str]] = []
|
||||
# Only include enabled models in the dropdown options
|
||||
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
|
||||
if model.is_enabled:
|
||||
options.append(
|
||||
{
|
||||
"label": model.display_name,
|
||||
"value": model.slug,
|
||||
"group": model.metadata.provider,
|
||||
"description": model.description or "",
|
||||
}
|
||||
)
|
||||
return options
|
||||
return [
|
||||
{
|
||||
"label": model.display_name,
|
||||
"value": model.slug,
|
||||
"group": model.metadata.provider,
|
||||
"description": model.description or "",
|
||||
}
|
||||
for model in sorted(
|
||||
_dynamic_models.values(), key=lambda m: m.display_name.lower()
|
||||
)
|
||||
if model.is_enabled
|
||||
]
|
||||
|
||||
|
||||
def get_model(slug: str) -> RegistryModel | None:
|
||||
@@ -219,17 +218,18 @@ def get_schema_options() -> list[dict[str, str]]:
|
||||
|
||||
def get_default_model_slug() -> str | None:
|
||||
"""Get the default model slug (first recommended, or first enabled)."""
|
||||
# Prefer recommended models (sorted for deterministic selection)
|
||||
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name):
|
||||
if model.is_recommended and model.is_enabled:
|
||||
return model.slug
|
||||
|
||||
# Sort once and use next() to short-circuit on first match
|
||||
models = sorted(_dynamic_models.values(), key=lambda m: m.display_name)
|
||||
|
||||
# Prefer recommended models
|
||||
recommended = next(
|
||||
(m.slug for m in models if m.is_recommended and m.is_enabled), None
|
||||
)
|
||||
if recommended:
|
||||
return recommended
|
||||
|
||||
# Fallback to first enabled model
|
||||
for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name):
|
||||
if model.is_enabled:
|
||||
return model.slug
|
||||
|
||||
return None
|
||||
return next((m.slug for m in models if m.is_enabled), None)
|
||||
|
||||
|
||||
def get_all_model_slugs_for_validation() -> list[str]:
|
||||
|
||||
Reference in New Issue
Block a user