mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(platform): Add LLM registry core - DB layer + in-memory cache
Implements the registry core for dynamic LLM model management: **DB Layer:** - Fetch models with provider, costs, and creator relations - Prisma query with includes for related data - Convert DB records to typed dataclasses **In-memory Cache:** - Global dict for fast model lookups - Atomic cache refresh with lock protection - Schema options generation for UI dropdowns **Public API:** - get_model(slug) - lookup by slug - get_all_models() - all models (including disabled) - get_enabled_models() - enabled models only - get_schema_options() - UI dropdown data - get_default_model_slug() - recommended or first enabled - refresh_llm_registry() - manual refresh trigger **Integration:** - Refresh at API startup (before block init) - Graceful fallback if registry unavailable - Enables blocks to consume registry data **Models:** - RegistryModel - full model with metadata - RegistryModelCost - pricing configuration - RegistryModelCreator - model creator info - ModelMetadata - context window, capabilities **Next PRs:** - PR #3: Public read API (GET endpoints) - PR #4: Admin write API (POST/PATCH/DELETE) - PR #5: Block integration (update LLM block) - PR #6: Redis cache (solve thundering herd) Lines: ~230 (registry.py ~210, __init__.py ~30, model.py from draft) Files: 4 (3 new, 1 modified)
This commit is contained in:
@@ -117,6 +117,14 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
AutoRegistry.patch_integrations()
|
||||
|
||||
# Refresh LLM registry before initializing blocks so blocks can use registry data
|
||||
try:
|
||||
from backend.data.llm_registry import refresh_llm_registry
|
||||
|
||||
await refresh_llm_registry()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to refresh LLM registry at startup: {e}")
|
||||
|
||||
await backend.data.block.initialize_blocks()
|
||||
|
||||
await backend.data.user.migrate_and_encrypt_user_integrations()
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
"""LLM Registry - Dynamic model management system."""
|
||||
|
||||
from .model import ModelMetadata
|
||||
from .registry import (
|
||||
RegistryModel,
|
||||
RegistryModelCost,
|
||||
RegistryModelCreator,
|
||||
get_all_model_slugs_for_validation,
|
||||
get_all_models,
|
||||
get_default_model_slug,
|
||||
get_enabled_models,
|
||||
get_model,
|
||||
get_schema_options,
|
||||
refresh_llm_registry,
|
||||
)
|
||||
|
||||
# Public, stable API of backend.data.llm_registry.
# Keep this list in sync with the imports above; anything not listed here
# is considered internal and may change without notice.
__all__ = [
    # Models
    "ModelMetadata",
    "RegistryModel",
    "RegistryModelCost",
    "RegistryModelCreator",
    # Functions
    "refresh_llm_registry",
    "get_model",
    "get_all_models",
    "get_enabled_models",
    "get_schema_options",
    "get_default_model_slug",
    "get_all_model_slugs_for_validation",
]
|
||||
25
autogpt_platform/backend/backend/data/llm_registry/model.py
Normal file
25
autogpt_platform/backend/backend/data/llm_registry/model.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Type definitions for LLM model metadata."""
|
||||
|
||||
from typing import Literal, NamedTuple
|
||||
|
||||
|
||||
class ModelMetadata(NamedTuple):
|
||||
"""Metadata for an LLM model.
|
||||
|
||||
Attributes:
|
||||
provider: The provider identifier (e.g., "openai", "anthropic")
|
||||
context_window: Maximum context window size in tokens
|
||||
max_output_tokens: Maximum output tokens (None if unlimited)
|
||||
display_name: Human-readable name for the model
|
||||
provider_name: Human-readable provider name (e.g., "OpenAI", "Anthropic")
|
||||
creator_name: Name of the organization that created the model
|
||||
price_tier: Relative cost tier (1=cheapest, 2=medium, 3=expensive)
|
||||
"""
|
||||
|
||||
provider: str
|
||||
context_window: int
|
||||
max_output_tokens: int | None
|
||||
display_name: str
|
||||
provider_name: str
|
||||
creator_name: str
|
||||
price_tier: Literal[1, 2, 3]
|
||||
227
autogpt_platform/backend/backend/data/llm_registry/registry.py
Normal file
227
autogpt_platform/backend/backend/data/llm_registry/registry.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""Core LLM registry implementation for managing models dynamically."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import prisma.models
|
||||
|
||||
from backend.data.llm_registry.model import ModelMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _json_to_dict(value: Any) -> dict[str, Any]:
|
||||
"""Convert Prisma Json type to dict, with fallback to empty dict."""
|
||||
if value is None:
|
||||
return {}
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
# Prisma Json type should always be a dict at runtime
|
||||
return dict(value) if value else {}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RegistryModelCost:
    """Cost configuration for an LLM model.

    Built from the model's ``Costs`` relation during registry refresh; a
    model may carry several entries (see ``RegistryModel.costs``).
    """

    credit_cost: int  # credit amount from the DB (units defined by the billing layer)
    credential_provider: str  # provider the credentials belong to (e.g. "openai")
    credential_id: str | None  # specific credential this cost applies to, if any
    credential_type: str | None  # e.g. api-key vs oauth — TODO confirm against schema
    currency: str | None  # presumably an ISO currency code — confirm against schema
    metadata: dict[str, Any]  # free-form extra cost data from the Json column
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RegistryModelCreator:
    """Creator (organization) information for an LLM model.

    Built from the model's optional ``Creator`` relation; a model without
    one has ``RegistryModel.creator is None``.
    """

    id: str  # DB primary key of the creator record
    name: str  # machine-friendly creator name
    display_name: str  # human-readable creator name
    description: str | None
    website_url: str | None
    logo_url: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RegistryModel:
    """Represents a model in the LLM registry.

    Frozen snapshot of one DB record plus its relations, produced by
    ``refresh_llm_registry`` and served from the in-memory cache. Instances
    are immutable so they can be shared safely across concurrent readers.
    """

    slug: str  # unique identifier; key of the in-memory cache
    display_name: str
    description: str | None
    metadata: ModelMetadata  # provider/context-window details for blocks
    capabilities: dict[str, Any]  # parsed from the DB `capabilities` Json column
    extra_metadata: dict[str, Any]  # parsed from the DB `metadata` Json column
    provider_display_name: str
    is_enabled: bool  # disabled models are cached but hidden from dropdowns
    is_recommended: bool = False  # recommended models win get_default_model_slug()
    # tuple (not list) keeps the frozen dataclass deeply immutable/hashable
    costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
    creator: RegistryModelCreator | None = None
|
||||
|
||||
|
||||
# In-memory cache (will be replaced with Redis in PR #6).
# _dynamic_models maps model slug -> RegistryModel; _schema_options is the
# precomputed dropdown payload derived from it. refresh_llm_registry() swaps
# both by REBINDING (not mutating) under _lock, so readers always see a
# consistent snapshot without taking the lock themselves.
_dynamic_models: dict[str, RegistryModel] = {}
_schema_options: list[dict[str, str]] = []
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def refresh_llm_registry() -> None:
    """
    Refresh the LLM registry from the database.

    Fetches all models with their costs, providers, and creators, then
    atomically swaps the in-memory cache (``_dynamic_models`` and
    ``_schema_options``). Serialized via ``_lock`` so concurrent refreshes
    cannot interleave.

    Raises:
        Exception: any DB/parsing error is logged and re-raised so callers
            (e.g. API startup) can decide how to degrade.
    """
    async with _lock:
        try:
            records = await prisma.models.LlmModel.prisma().find_many(
                include={
                    "Provider": True,
                    "Costs": True,
                    "Creator": True,
                }
            )
            logger.info(f"Fetched {len(records)} LLM models from database")

            # Build the replacement cache off to the side; the live cache is
            # only swapped once every record parsed successfully.
            new_models: dict[str, RegistryModel] = {}
            for record in records:
                # Parse per-credential cost entries (missing relation -> empty).
                costs = tuple(
                    RegistryModelCost(
                        credit_cost=cost.creditCost,
                        credential_provider=cost.credentialProvider,
                        credential_id=cost.credentialId,
                        credential_type=cost.credentialType,
                        currency=cost.currency,
                        metadata=_json_to_dict(cost.metadata),
                    )
                    for cost in (record.Costs or [])
                )

                # Parse the optional creator relation.
                creator = None
                if record.Creator:
                    creator = RegistryModelCreator(
                        id=record.Creator.id,
                        name=record.Creator.name,
                        display_name=record.Creator.displayName,
                        description=record.Creator.description,
                        website_url=record.Creator.websiteUrl,
                        logo_url=record.Creator.logoUrl,
                    )

                capabilities = _json_to_dict(record.capabilities)
                extra_metadata = _json_to_dict(record.metadata)

                # Fall back to the raw providerId if the relation wasn't loaded.
                provider_name = (
                    record.Provider.name if record.Provider else record.providerId
                )
                provider_display = (
                    record.Provider.displayName
                    if record.Provider
                    else record.providerId
                )

                # FIX: construct ModelMetadata with the fields it actually
                # declares in model.py. The previous code passed a
                # non-existent `supports_vision` kwarg and omitted the
                # required display_name/provider_name/creator_name/price_tier
                # fields, so this raised TypeError for every record.
                metadata = ModelMetadata(
                    provider=provider_name,
                    context_window=record.contextWindow,
                    # Missing max is treated as bounded by the context window.
                    max_output_tokens=record.maxOutputTokens or record.contextWindow,
                    display_name=record.displayName,
                    provider_name=provider_display,
                    creator_name=creator.display_name if creator else "",
                    # NOTE(review): no price-tier column is visible on the
                    # record; read it from the model's extra metadata with a
                    # conservative default — TODO confirm key against schema.
                    price_tier=extra_metadata.get("priceTier", 1),
                )

                # Create the cached model instance.
                model = RegistryModel(
                    slug=record.slug,
                    display_name=record.displayName,
                    description=record.description,
                    metadata=metadata,
                    capabilities=capabilities,
                    extra_metadata=extra_metadata,
                    provider_display_name=provider_display,
                    is_enabled=record.isEnabled,
                    is_recommended=record.isRecommended,
                    costs=costs,
                    creator=creator,
                )
                new_models[record.slug] = model

            # Atomic swap: rebind the module globals so readers observe either
            # the old snapshot or the new one, never a partial mix.
            global _dynamic_models, _schema_options
            _dynamic_models = new_models
            _schema_options = _build_schema_options()

            logger.info(
                f"LLM registry refreshed: {len(_dynamic_models)} models, "
                f"{len(_schema_options)} schema options"
            )
        except Exception as e:
            logger.error(f"Failed to refresh LLM registry: {e}", exc_info=True)
            raise
|
||||
|
||||
|
||||
def _build_schema_options() -> list[dict[str, str]]:
    """Build dropdown options for model selection.

    Only enabled models are included; entries are ordered
    case-insensitively by display name.
    """
    enabled = [m for m in _dynamic_models.values() if m.is_enabled]
    enabled.sort(key=lambda m: m.display_name.lower())
    return [
        {
            "label": m.display_name,
            "value": m.slug,
            "group": m.metadata.provider,
            "description": m.description or "",
        }
        for m in enabled
    ]
|
||||
|
||||
|
||||
def get_model(slug: str) -> RegistryModel | None:
    """Look up a registry model by slug; None when not cached."""
    try:
        return _dynamic_models[slug]
    except KeyError:
        return None
|
||||
|
||||
|
||||
def get_all_models() -> list[RegistryModel]:
    """Return every cached model — enabled and disabled — as a new list."""
    return [*_dynamic_models.values()]
|
||||
|
||||
|
||||
def get_enabled_models() -> list[RegistryModel]:
    """Return only the cached models currently flagged as enabled."""
    return list(filter(lambda m: m.is_enabled, _dynamic_models.values()))
|
||||
|
||||
|
||||
def get_schema_options() -> list[dict[str, str]]:
    """Return the precomputed dropdown options (enabled models only).

    NOTE: returns the cached list itself; callers must not mutate it.
    """
    return _schema_options
|
||||
|
||||
|
||||
def get_default_model_slug() -> str | None:
    """Pick the default model slug.

    Preference order:
      1. First enabled model flagged as recommended (cache order).
      2. Otherwise the first enabled model sorted by display name.
      3. None when nothing is enabled.
    """
    models = _dynamic_models.values()
    recommended = next(
        (m.slug for m in models if m.is_recommended and m.is_enabled), None
    )
    if recommended is not None:
        return recommended
    by_name = sorted(models, key=lambda m: m.display_name)
    return next((m.slug for m in by_name if m.is_enabled), None)
|
||||
|
||||
|
||||
def get_all_model_slugs_for_validation() -> list[str]:
    """
    Get all model slugs for validation (enables migrate_llm_models to work).

    Returns slugs for enabled models only.
    """
    enabled = (m for m in _dynamic_models.values() if m.is_enabled)
    return [m.slug for m in enabled]
|
||||
Reference in New Issue
Block a user